Search in sources:

Example 6 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMLR method reportAccPrediction.

/**
 * Runs the instance set accuracy optimal predictor ({@link AccPredictor}) built on
 * {@code cbm} over {@code dataSet}, prints the resulting measures to standard output,
 * and saves two files under
 * {@code <output.dir>/test_predictions/instance_accuracy_optimal}:
 * {@code performance.txt} (the textual form of the measures) and
 * {@code predictions.txt} (one {@code prediction:probability} line per data point).
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     the model used both for prediction and for scoring each predicted set
 * @param dataSet the data set to predict on and to evaluate against
 * @throws Exception if prediction or any file write fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputDir = config.getString("output.dir");

    AccPredictor predictor = new AccPredictor(cbm);
    predictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = predictor.predict(dataSet);

    // Evaluate the predictions against the labels stored in the data set.
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    String measuresText = measures.toString();
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measuresText);

    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measuresText);
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();

    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numDataPoints; row++) {
            // One line per data point: "<predicted set>:<probability>".
            writer.write(predictions[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 7 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class App1 method ngramSelection.

/**
 * Selects, for every label, the best-scoring ngrams with n &gt; 1 and rewrites the
 * stored feature list so that it keeps only non-ngram features, unigrams, and the
 * selected higher-order ngrams.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>load the id translator, label translator and the label matrix;</li>
 *   <li>score every ngram with n &gt; 1 against each label via {@code StumpSelector.scores}
 *       and offer it to that label's bounded queue (capacity
 *       {@code train.feature.ngram.selectPerLabel});</li>
 *   <li>drain the queues and write the surviving ngrams to {@code selected_ngrams.txt};</li>
 *   <li>copy the previous {@code feature_list.ser}/{@code feature_list.txt} to
 *       {@code feature_list_all.*} and overwrite the originals with the reduced list.</li>
 * </ol>
 *
 * @param config    supplies {@code output.folder}, {@code train.feature.ngram.matchScoreType},
 *                  {@code train.splitQuery} and {@code train.feature.ngram.selectPerLabel}
 * @param index     the index passed through to document lookup, label loading and scoring
 * @param docFilter filter forwarded to {@code StumpSelector.scores}
 * @param logger    receives progress messages
 * @throws IllegalArgumentException if the configured match score type is not recognised
 * @throws Exception                if deserialization, scoring or file I/O fails
 */
private static void ngramSelection(Config config, MultiLabelIndex index, String docFilter, Logger logger) throws Exception {
    logger.info("start ngram selection");
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    String scoreTypeName = config.getString("train.feature.ngram.matchScoreType");
    String[] docIds = getDocsForSplitFromQuery(index, config.getString("train.splitQuery"));
    IdTranslator idTranslator = loadIdTranslator(docIds);
    LabelTranslator labelTranslator = (LabelTranslator) Serialization.deserialize(new File(metaDataFolder, "label_translator.ser"));

    // Translate the configured name into the enum constant used for scoring.
    final FeatureLoader.MatchScoreType matchScoreType;
    switch (scoreTypeName) {
        case "es_original":
            matchScoreType = FeatureLoader.MatchScoreType.ES_ORIGINAL;
            break;
        case "binary":
            matchScoreType = FeatureLoader.MatchScoreType.BINARY;
            break;
        case "frequency":
            matchScoreType = FeatureLoader.MatchScoreType.FREQUENCY;
            break;
        case "tfifl":
            matchScoreType = FeatureLoader.MatchScoreType.TFIFL;
            break;
        default:
            throw new IllegalArgumentException("unknown ngramMatchScoreType");
    }

    double[][] labels = loadLabels(config, index, idTranslator, labelTranslator);
    int numLabels = labels.length;
    int capacity = config.getInt("train.feature.ngram.selectPerLabel");

    // One bounded queue per label, ordered by the score stored in each pair.
    Comparator<Pair<Ngram, Double>> byScore = Comparator.comparing(Pair::getSecond);
    List<BoundedBlockPriorityQueue<Pair<Ngram, Double>>> queues = new ArrayList<>();
    for (int label = 0; label < numLabels; label++) {
        queues.add(new BoundedBlockPriorityQueue<>(capacity, byScore));
    }

    File featureListSer = new File(metaDataFolder, "feature_list.ser");
    FeatureList featureList = (FeatureList) Serialization.deserialize(featureListSer);

    // Score every higher-order ngram in parallel and offer it to each label's queue.
    featureList.getAll().stream().parallel()
            .filter(Ngram.class::isInstance)
            .map(Ngram.class::cast)
            .filter(candidate -> candidate.getN() > 1)
            .forEach(candidate -> {
                double[] perLabelScores = StumpSelector.scores(index, labels, candidate, idTranslator, matchScoreType, docFilter);
                int label = 0;
                for (BoundedBlockPriorityQueue<Pair<Ngram, Double>> queue : queues) {
                    queue.add(new Pair<>(candidate, perLabelScores[label++]));
                }
            });

    // Drain the queues: collect the union of retained ngrams and build a per-label listing.
    Set<Ngram> kept = new HashSet<>();
    StringBuilder listing = new StringBuilder();
    for (int label = 0; label < numLabels; label++) {
        listing.append("-------------------------\n");
        listing.append(labelTranslator.toExtLabel(label)).append(":\n");
        BoundedBlockPriorityQueue<Pair<Ngram, Double>> queue = queues.get(label);
        while (queue.size() != 0) {
            Ngram selected = queue.poll().getFirst();
            kept.add(selected);
            listing.append(selected.getNgram()).append(", ");
        }
        listing.append("\n");
    }
    File selectionFile = new File(metaDataFolder, "selected_ngrams.txt");
    FileUtils.writeStringToFile(selectionFile, listing.toString());
    logger.info("finish ngram selection");
    logger.info("selected ngrams are written to " + selectionFile.getAbsolutePath());

    // after feature selection, overwrite the feature_list.ser file; rename old files
    FeatureList selectedFeatures = new FeatureList();
    for (Feature feature : featureList.getAll()) {
        boolean keep;
        if (feature instanceof Ngram) {
            int n = ((Ngram) feature).getN();
            // Unigrams always survive; longer ngrams only if some label selected them.
            keep = n == 1 || (n > 1 && kept.contains(feature));
        } else {
            keep = true;
        }
        if (keep) {
            selectedFeatures.add(feature);
        }
    }

    File featureListTxt = new File(metaDataFolder, "feature_list.txt");
    FileUtils.copyFile(featureListSer, new File(metaDataFolder, "feature_list_all.ser"));
    FileUtils.copyFile(featureListTxt, new File(metaDataFolder, "feature_list_all.txt"));
    Serialization.serialize(selectedFeatures, featureListSer);
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(featureListTxt))) {
        for (Feature feature : selectedFeatures.getAll()) {
            writer.write(feature.toString());
            writer.newLine();
        }
    }
}
Also used : java.util.logging(java.util.logging) java.util(java.util) BoundedBlockPriorityQueue(edu.neu.ccs.pyramid.util.BoundedBlockPriorityQueue) Multiset(com.google.common.collect.Multiset) NgramEnumerator(edu.neu.ccs.pyramid.feature_extraction.NgramEnumerator) edu.neu.ccs.pyramid.feature(edu.neu.ccs.pyramid.feature) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FeatureLoader(edu.neu.ccs.pyramid.elasticsearch.FeatureLoader) Terms(org.elasticsearch.search.aggregations.bucket.terms.Terms) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) FileWriter(java.io.FileWriter) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelIndex(edu.neu.ccs.pyramid.elasticsearch.MultiLabelIndex) ConcurrentHashMultiset(com.google.common.collect.ConcurrentHashMultiset) ESIndex(edu.neu.ccs.pyramid.elasticsearch.ESIndex) NgramTemplate(edu.neu.ccs.pyramid.feature_extraction.NgramTemplate) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) StumpSelector(edu.neu.ccs.pyramid.feature_extraction.StumpSelector) Pattern(java.util.regex.Pattern) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) Pair(edu.neu.ccs.pyramid.util.Pair) BoundedBlockPriorityQueue(edu.neu.ccs.pyramid.util.BoundedBlockPriorityQueue) FeatureLoader(edu.neu.ccs.pyramid.elasticsearch.FeatureLoader) File(java.io.File)

Example 8 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class App1 method main.

/**
 * Command-line entry point: loads the properties file named by the single argument
 * into a {@link Config} and delegates to {@code main(Config)}.
 *
 * @param args must contain exactly one element, the path of the properties file
 * @throws IllegalArgumentException if the number of arguments is not exactly one
 * @throws Exception                if loading the config or the delegated run fails
 */
public static void main(String[] args) throws Exception {
    if (args.length == 1) {
        main(new Config(args[0]));
        return;
    }
    throw new IllegalArgumentException("Please specify a properties file.");
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Example 9 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class App2 method main.

/**
 * Runs the train / tune / test stages selected in {@code config}.
 *
 * <p>If {@code output.log} is non-empty, log records go to that file (opened in append
 * mode) instead of the parent handlers. The file handler is always closed when this
 * method exits, including when a stage throws.
 *
 * <p>Stages, each enabled by its own boolean key:
 * <ul>
 *   <li>{@code train}: trains, optionally reports on the training data (unless
 *       {@code predict.target} is {@code macroFMeasure}, which requires tuning first),
 *       and stores the config as {@code saved_config_app2} under
 *       {@code <input.folder>/meta_data};</li>
 *   <li>{@code tune}: tunes for macro F and reports on the training data if the saved
 *       config had {@code train.generateReports} set;</li>
 *   <li>{@code test}: reports on the test data.</li>
 * </ul>
 *
 * @param config the run configuration
 * @throws Exception if any stage or any I/O operation fails
 */
public static void main(Config config) throws Exception {
    Logger logger = Logger.getAnonymousLogger();
    String logFile = config.getString("output.log");
    FileHandler fileHandler = null;
    try {
        if (!logFile.isEmpty()) {
            // A bare file name such as "app.log" has no parent; getParentFile() then
            // returns null, so only create directories when there is a parent to create.
            File logParent = new File(logFile).getParentFile();
            if (logParent != null) {
                logParent.mkdirs();
            }
            // TODO: confirm that appending to an existing log file is the desired behaviour.
            fileHandler = new FileHandler(logFile, true);
            fileHandler.setFormatter(new SimpleFormatter());
            logger.addHandler(fileHandler);
            logger.setUseParentHandlers(false);
        }
        logger.info(config.toString());
        new File(config.getString("output.folder")).mkdirs();
        if (config.getBoolean("train")) {
            train(config, logger);
            if (config.getString("predict.target").equals("macroFMeasure")) {
                logger.info("predict.target=macroFMeasure,  user needs to run 'tune' before predictions can be made. " + "Reports will be generated after tuning.");
            } else {
                if (config.getBoolean("train.generateReports")) {
                    report(config, config.getString("input.trainData"), logger);
                }
            }
            File metaDataFolder = new File(config.getString("input.folder"), "meta_data");
            config.store(new File(metaDataFolder, "saved_config_app2"));
        }
        if (config.getBoolean("tune")) {
            tuneForMacroF(config, logger);
            File metaDataFolder = new File(config.getString("input.folder"), "meta_data");
            // The report decision comes from the config saved at training time,
            // not from the config of the current run.
            Config savedConfig = new Config(new File(metaDataFolder, "saved_config_app2"));
            if (savedConfig.getBoolean("train.generateReports")) {
                report(config, config.getString("input.trainData"), logger);
            }
        }
        if (config.getBoolean("test")) {
            report(config, config.getString("input.testData"), logger);
        }
    } finally {
        // Release the log file (and its lock file) even if a stage failed.
        if (fileHandler != null) {
            fileHandler.close();
        }
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) SimpleFormatter(java.util.logging.SimpleFormatter) Logger(java.util.logging.Logger) File(java.io.File) FileHandler(java.util.logging.FileHandler)

Example 10 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMGB method reportHammingPrediction.

/**
 * Runs the instance Hamming loss optimal predictor ({@link MarginalPredictor}) built on
 * {@code cbm} over {@code dataSet}, prints the resulting measures to standard output,
 * and saves two files under
 * {@code <output.dir>/test_predictions/instance_hamming_loss_optimal}:
 * {@code performance.txt} (the textual form of the measures) and
 * {@code predictions.txt} (one {@code prediction:probability} line per data point).
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     the model used both for prediction and for scoring each predicted set
 * @param dataSet the data set to predict on and to evaluate against
 * @throws Exception if prediction or any file write fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor predictor = new MarginalPredictor(cbm);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = predictor.predict(dataSet);

    // Evaluate the predictions against the labels stored in the data set.
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    String measuresText = measures.toString();
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measuresText);

    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measuresText);
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();

    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numDataPoints; row++) {
            // One line per data point: "<predicted set>:<probability>".
            writer.write(predictions[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config)59 File (java.io.File)35 Collectors (java.util.stream.Collectors)18 FileUtils (org.apache.commons.io.FileUtils)18 StopWatch (org.apache.commons.lang3.time.StopWatch)18 Serialization (edu.neu.ccs.pyramid.util.Serialization)17 BufferedWriter (java.io.BufferedWriter)17 FileWriter (java.io.FileWriter)17 Paths (java.nio.file.Paths)17 IntStream (java.util.stream.IntStream)16 Pair (edu.neu.ccs.pyramid.util.Pair)15 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)14 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)14 java.util (java.util)14 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)13 PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil)13 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)12 edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm)12 ListUtil (edu.neu.ccs.pyramid.util.ListUtil)12 IOException (java.io.IOException)8