Search in sources :

Example 1 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMEN method reportGeneral.

/**
 * Writes predictor-independent diagnostics for the given CBM and data set below the directory
 * configured as {@code output.dir}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>writes one line per data point, produced by {@code CBMInspector.topLabels} with the
 *       configured {@code report.labelProbThreshold}, to
 *       {@code test_predictions/label_probabilities.txt};</li>
 *   <li>reads the comma-separated integer label ids stored in {@code unobserved_labels.txt};</li>
 *   <li>computes {@code cbm.predictLogAssignmentProb} of every data point's ground-truth label set
 *       and writes the values to {@code test_predictions/ground_truth_log_likelihood.txt};</li>
 *   <li>prints the average of those values over the data points for which
 *       {@code containsNovelClass} is false with respect to the unobserved labels.</li>
 * </ol>
 *
 * <p>If no data point qualifies for the average (empty data set, or every data point is
 * filtered out), the average is reported as {@code NaN} instead of failing with a
 * {@link java.util.NoSuchElementException}.
 *
 * @param config  configuration supplying {@code output.dir} and {@code report.labelProbThreshold}
 * @param cbm     model whose probabilities are reported
 * @param dataSet data set to report on
 * @throws Exception if reading or writing any of the files fails
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    // FileWriter does not create missing parent directories, so make sure the target
    // directory exists before opening the first file inside it.
    FileUtils.forceMkdir(labelProbFile.getParentFile());
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // The file holds comma-separated label ids; blank entries (e.g. an empty file) are skipped.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints())
            .parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // The average is taken only over data points without unobserved labels. That selection can
    // be empty, in which case NaN is reported rather than throwing from getAsDouble().
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average()
            .orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 2 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMEN method reportAccPrediction.

/**
 * Predicts label sets for the given data set with an {@code AccPredictor} built on the CBM,
 * prints the resulting {@code MLMeasures}, and stores both the measures and the predictions
 * under {@code <output.dir>/test_predictions/instance_accuracy_optimal}.
 *
 * <p>{@code performance.txt} receives the string form of the measures. {@code predictions.txt}
 * receives one line per data point in the form {@code <predicted set>:<probability>}, where the
 * probability comes from {@code cbm.predictAssignmentProb}.
 *
 * @param config  configuration supplying {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used for prediction
 * @param dataSet data set to predict on
 * @throws Exception if prediction or file output fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputDir = config.getString("output.dir");

    AccPredictor predictor = new AccPredictor(cbm);
    double piThreshold = config.getDouble("predict.piThreshold");
    predictor.setComponentContributionThreshold(piThreshold);
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measures);

    // Both result files live in the same directory.
    File resultDir = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal").toFile();
    File performanceFile = new File(resultDir, "performance.txt");
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // No approximation is used for the probability of each predicted set.
    double[] setProbs = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predicted[i]))
            .toArray();

    File predictionFile = new File(resultDir, "predictions.txt");
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            writer.write(predicted[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separator);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 3 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMLR method reportHammingPrediction.

/**
 * Predicts label sets for the given data set with a {@code MarginalPredictor} built on the CBM,
 * prints the resulting {@code MLMeasures}, and stores both the measures and the predictions
 * under {@code <output.dir>/test_predictions/instance_hamming_loss_optimal}.
 *
 * <p>{@code performance.txt} receives the string form of the measures. {@code predictions.txt}
 * receives one line per data point in the form {@code <predicted set>:<probability>}, where the
 * probability comes from {@code cbm.predictAssignmentProb}.
 *
 * @param config  configuration supplying {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used for prediction
 * @param dataSet data set to predict on
 * @throws Exception if prediction or file output fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor predictor = new MarginalPredictor(cbm);
    double piThreshold = config.getDouble("predict.piThreshold");
    predictor.setPiThreshold(piThreshold);
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measures);

    // Both result files live in the same directory.
    File resultDir = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal").toFile();
    File performanceFile = new File(resultDir, "performance.txt");
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // No approximation is used for the probability of each predicted set.
    double[] setProbs = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predicted[i]))
            .toArray();

    File predictionFile = new File(resultDir, "predictions.txt");
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            writer.write(predicted[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separator);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 4 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMLR method reportAccPrediction.

/**
 * Runs an {@code AccPredictor} backed by the CBM over the given data set, prints the resulting
 * {@code MLMeasures}, and saves the measures and the predicted sets under
 * {@code <output.dir>/test_predictions/instance_accuracy_optimal}.
 *
 * <p>The measures go to {@code performance.txt}. Each line of {@code predictions.txt} has the
 * form {@code <predicted set>:<probability>}, with the probability obtained from
 * {@code cbm.predictAssignmentProb}.
 *
 * @param config  configuration supplying {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used for prediction
 * @param dataSet data set to predict on
 * @throws Exception if prediction or file output fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String rule = "============================================================";
    System.out.println(rule);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputRoot = config.getString("output.dir");

    AccPredictor setPredictor = new AccPredictor(cbm);
    double contributionThreshold = config.getDouble("predict.piThreshold");
    setPredictor.setComponentContributionThreshold(contributionThreshold);
    MultiLabel[] predictedSets = setPredictor.predict(dataSet);

    MLMeasures performance = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictedSets);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(performance);

    // Shared parent directory of the two files written below.
    File targetDir = Paths.get(outputRoot, "test_predictions", "instance_accuracy_optimal").toFile();
    File performanceFile = new File(targetDir, "performance.txt");
    FileUtils.writeStringToFile(performanceFile, performance.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // No approximation is used for the probability of each predicted set.
    double[] probabilities = IntStream.range(0, predictedSets.length)
            .parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictedSets[i]))
            .toArray();

    File predictionFile = new File(targetDir, "predictions.txt");
    try (BufferedWriter out = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            out.write(predictedSets[row].toString() + ":" + probabilities[row]);
            out.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(rule);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 5 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class App1 method ngramSelection.

/**
 * Scores every ngram feature with {@code N > 1} against each label and rewrites the stored
 * feature list so that, among those ngrams, only the ones retained by the per-label queues remain.
 *
 * <p>Inputs are read from {@code <output.folder>/meta_data}: {@code label_translator.ser} and
 * {@code feature_list.ser}. Scores come from {@code StumpSelector.scores}; each label has a
 * {@code BoundedBlockPriorityQueue} of capacity {@code train.feature.ngram.selectPerLabel},
 * ordered by score. (Which end of the ordering the queue retains is decided by
 * {@code BoundedBlockPriorityQueue} and is not visible here.)
 *
 * <p>Outputs, all in the same folder:
 * <ul>
 *   <li>{@code selected_ngrams.txt}: the ngrams drained from each label's queue;</li>
 *   <li>{@code feature_list_all.ser} / {@code feature_list_all.txt}: copies of the previous
 *       feature list files;</li>
 *   <li>{@code feature_list.ser} / {@code feature_list.txt}: the new list, holding every
 *       non-ngram feature, every ngram with {@code N == 1}, and the retained ngrams with
 *       {@code N > 1}.</li>
 * </ul>
 *
 * @param config    configuration supplying the folder, match score type, split query and
 *                  per-label selection size
 * @param index     index passed to the document query, label loading and scoring
 * @param docFilter filter passed through to {@code StumpSelector.scores}
 * @param logger    receives progress messages
 * @throws IllegalArgumentException if {@code train.feature.ngram.matchScoreType} is not one of
 *                                  {@code es_original}, {@code binary}, {@code frequency},
 *                                  {@code tfifl}
 * @throws Exception if loading, scoring or writing fails
 */
private static void ngramSelection(Config config, MultiLabelIndex index, String docFilter, Logger logger) throws Exception {
    logger.info("start ngram selection");
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    String scoreTypeName = config.getString("train.feature.ngram.matchScoreType");
    String[] indexIds = getDocsForSplitFromQuery(index, config.getString("train.splitQuery"));
    IdTranslator idTranslator = loadIdTranslator(indexIds);
    LabelTranslator labelTranslator = (LabelTranslator) Serialization.deserialize(new File(metaDataFolder, "label_translator.ser"));

    final FeatureLoader.MatchScoreType matchScoreType;
    switch (scoreTypeName) {
        case "es_original":
            matchScoreType = FeatureLoader.MatchScoreType.ES_ORIGINAL;
            break;
        case "binary":
            matchScoreType = FeatureLoader.MatchScoreType.BINARY;
            break;
        case "frequency":
            matchScoreType = FeatureLoader.MatchScoreType.FREQUENCY;
            break;
        case "tfifl":
            matchScoreType = FeatureLoader.MatchScoreType.TFIFL;
            break;
        default:
            throw new IllegalArgumentException("unknown ngramMatchScoreType");
    }

    double[][] labels = loadLabels(config, index, idTranslator, labelTranslator);
    int numLabels = labels.length;
    int toKeep = config.getInt("train.feature.ngram.selectPerLabel");

    // One bounded queue per label, ordered by the score stored in the pair.
    Comparator<Pair<Ngram, Double>> byScore = Comparator.comparing(Pair::getSecond);
    List<BoundedBlockPriorityQueue<Pair<Ngram, Double>>> queues = new ArrayList<>();
    for (int label = 0; label < numLabels; label++) {
        queues.add(new BoundedBlockPriorityQueue<>(toKeep, byScore));
    }

    File featureListSer = new File(metaDataFolder, "feature_list.ser");
    File featureListTxt = new File(metaDataFolder, "feature_list.txt");
    FeatureList featureList = (FeatureList) Serialization.deserialize(featureListSer);

    // Only ngrams with N > 1 are scored; each is offered to every label's queue.
    featureList.getAll().stream().parallel()
            .filter(Ngram.class::isInstance)
            .map(Ngram.class::cast)
            .filter(candidate -> candidate.getN() > 1)
            .forEach(candidate -> {
                double[] scores = StumpSelector.scores(index, labels, candidate, idTranslator, matchScoreType, docFilter);
                for (int label = 0; label < numLabels; label++) {
                    queues.get(label).add(new Pair<>(candidate, scores[label]));
                }
            });

    // Drain the queues, remembering every retained ngram and building the text report.
    Set<Ngram> kept = new HashSet<>();
    StringBuilder report = new StringBuilder();
    for (int label = 0; label < numLabels; label++) {
        report.append("-------------------------\n");
        report.append(labelTranslator.toExtLabel(label)).append(":\n");
        BoundedBlockPriorityQueue<Pair<Ngram, Double>> queue = queues.get(label);
        while (queue.size() > 0) {
            Ngram selected = queue.poll().getFirst();
            kept.add(selected);
            report.append(selected.getNgram()).append(", ");
        }
        report.append("\n");
    }
    File selectionFile = new File(metaDataFolder, "selected_ngrams.txt");
    FileUtils.writeStringToFile(selectionFile, report.toString());
    logger.info("finish ngram selection");
    logger.info("selected ngrams are written to " + selectionFile.getAbsolutePath());

    // after feature selection, overwrite the feature_list.ser file; rename old files
    FeatureList selectedFeatures = new FeatureList();
    for (Feature feature : featureList.getAll()) {
        if (!(feature instanceof Ngram)) {
            // Non-ngram features are always carried over.
            selectedFeatures.add(feature);
        } else {
            int n = ((Ngram) feature).getN();
            // Unigrams are always carried over; longer ngrams only if some queue retained them.
            if (n == 1 || (n > 1 && kept.contains(feature))) {
                selectedFeatures.add(feature);
            }
        }
    }
    FileUtils.copyFile(featureListSer, new File(metaDataFolder, "feature_list_all.ser"));
    FileUtils.copyFile(featureListTxt, new File(metaDataFolder, "feature_list_all.txt"));
    Serialization.serialize(selectedFeatures, featureListSer);
    try (BufferedWriter featureWriter = new BufferedWriter(new FileWriter(featureListTxt))) {
        for (Feature feature : selectedFeatures.getAll()) {
            featureWriter.write(feature.toString());
            featureWriter.newLine();
        }
    }
}
Also used : java.util.logging(java.util.logging) java.util(java.util) BoundedBlockPriorityQueue(edu.neu.ccs.pyramid.util.BoundedBlockPriorityQueue) Multiset(com.google.common.collect.Multiset) NgramEnumerator(edu.neu.ccs.pyramid.feature_extraction.NgramEnumerator) edu.neu.ccs.pyramid.feature(edu.neu.ccs.pyramid.feature) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FeatureLoader(edu.neu.ccs.pyramid.elasticsearch.FeatureLoader) Terms(org.elasticsearch.search.aggregations.bucket.terms.Terms) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) FileWriter(java.io.FileWriter) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelIndex(edu.neu.ccs.pyramid.elasticsearch.MultiLabelIndex) ConcurrentHashMultiset(com.google.common.collect.ConcurrentHashMultiset) ESIndex(edu.neu.ccs.pyramid.elasticsearch.ESIndex) NgramTemplate(edu.neu.ccs.pyramid.feature_extraction.NgramTemplate) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) StumpSelector(edu.neu.ccs.pyramid.feature_extraction.StumpSelector) Pattern(java.util.regex.Pattern) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) Pair(edu.neu.ccs.pyramid.util.Pair) BoundedBlockPriorityQueue(edu.neu.ccs.pyramid.util.BoundedBlockPriorityQueue) FeatureLoader(edu.neu.ccs.pyramid.elasticsearch.FeatureLoader) File(java.io.File)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config)119 File (java.io.File)68 Collectors (java.util.stream.Collectors)40 FileUtils (org.apache.commons.io.FileUtils)40 Paths (java.nio.file.Paths)39 IntStream (java.util.stream.IntStream)37 Pair (edu.neu.ccs.pyramid.util.Pair)36 Serialization (edu.neu.ccs.pyramid.util.Serialization)35 StopWatch (org.apache.commons.lang3.time.StopWatch)34 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)33 BufferedWriter (java.io.BufferedWriter)32 FileWriter (java.io.FileWriter)32 java.util (java.util)32 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)31 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)29 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)28 PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil)26 edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm)25 ListUtil (edu.neu.ccs.pyramid.util.ListUtil)25 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)22