Example 1 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.

From the class App1, the method ngramSelection:

private static void ngramSelection(Config config, MultiLabelIndex index, String docFilter, Logger logger) throws Exception {
    logger.info("start ngram selection");
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    FeatureLoader.MatchScoreType matchScoreType;
    String matchScoreTypeString = config.getString("train.feature.ngram.matchScoreType");
    String[] indexIds = getDocsForSplitFromQuery(index, config.getString("train.splitQuery"));
    IdTranslator idTranslator = loadIdTranslator(indexIds);
    LabelTranslator labelTranslator = (LabelTranslator) Serialization.deserialize(new File(metaDataFolder, "label_translator.ser"));
    switch(matchScoreTypeString) {
        case "es_original":
            matchScoreType = FeatureLoader.MatchScoreType.ES_ORIGINAL;
            break;
        case "binary":
            matchScoreType = FeatureLoader.MatchScoreType.BINARY;
            break;
        case "frequency":
            matchScoreType = FeatureLoader.MatchScoreType.FREQUENCY;
            break;
        case "tfifl":
            matchScoreType = FeatureLoader.MatchScoreType.TFIFL;
            break;
        default:
            throw new IllegalArgumentException("unknown ngramMatchScoreType");
    }
    double[][] labels = loadLabels(config, index, idTranslator, labelTranslator);
    int numLabels = labels.length;
    int toKeep = config.getInt("train.feature.ngram.selectPerLabel");
    List<BoundedBlockPriorityQueue<Pair<Ngram, Double>>> queues = new ArrayList<>();
    Comparator<Pair<Ngram, Double>> comparator = Comparator.comparing(p -> p.getSecond());
    for (int l = 0; l < numLabels; l++) {
        queues.add(new BoundedBlockPriorityQueue<>(toKeep, comparator));
    }
    FeatureList featureList = (FeatureList) Serialization.deserialize(new File(metaDataFolder, "feature_list.ser"));
    featureList.getAll().stream().parallel()
            .filter(feature -> feature instanceof Ngram)
            .map(feature -> (Ngram) feature)
            .filter(ngram -> ngram.getN() > 1)
            .forEach(ngram -> {
        double[] scores = StumpSelector.scores(index, labels, ngram, idTranslator, matchScoreType, docFilter);
        for (int l = 0; l < numLabels; l++) {
            queues.get(l).add(new Pair<>(ngram, scores[l]));
        }
    });
    Set<Ngram> kept = new HashSet<>();
    StringBuilder stringBuilder = new StringBuilder();
    for (int l = 0; l < numLabels; l++) {
        stringBuilder.append("-------------------------").append("\n");
        stringBuilder.append(labelTranslator.toExtLabel(l)).append(":").append("\n");
        BoundedBlockPriorityQueue<Pair<Ngram, Double>> queue = queues.get(l);
        while (queue.size() > 0) {
            Ngram ngram = queue.poll().getFirst();
            kept.add(ngram);
            stringBuilder.append(ngram.getNgram()).append(", ");
        }
        stringBuilder.append("\n");
    }
    File selectionFile = new File(metaDataFolder, "selected_ngrams.txt");
    FileUtils.writeStringToFile(selectionFile, stringBuilder.toString());
    logger.info("finish ngram selection");
    logger.info("selected ngrams are written to " + selectionFile.getAbsolutePath());
    // after feature selection, overwrite the feature_list.ser file; rename old files
    FeatureList selectedFeatures = new FeatureList();
    for (Feature feature : featureList.getAll()) {
        if (!(feature instanceof Ngram)) {
            selectedFeatures.add(feature);
        }
        if ((feature instanceof Ngram) && ((Ngram) feature).getN() == 1) {
            selectedFeatures.add(feature);
        }
        if ((feature instanceof Ngram) && ((Ngram) feature).getN() > 1 && kept.contains(feature)) {
            selectedFeatures.add(feature);
        }
    }
    FileUtils.copyFile(new File(metaDataFolder, "feature_list.ser"), new File(metaDataFolder, "feature_list_all.ser"));
    FileUtils.copyFile(new File(metaDataFolder, "feature_list.txt"), new File(metaDataFolder, "feature_list_all.txt"));
    Serialization.serialize(selectedFeatures, new File(metaDataFolder, "feature_list.ser"));
    try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(metaDataFolder, "feature_list.txt")))) {
        for (Feature feature : selectedFeatures.getAll()) {
            bufferedWriter.write(feature.toString());
            bufferedWriter.newLine();
        }
    }
}
Also used : java.util.logging(java.util.logging) java.util(java.util) BoundedBlockPriorityQueue(edu.neu.ccs.pyramid.util.BoundedBlockPriorityQueue) Multiset(com.google.common.collect.Multiset) NgramEnumerator(edu.neu.ccs.pyramid.feature_extraction.NgramEnumerator) edu.neu.ccs.pyramid.feature(edu.neu.ccs.pyramid.feature) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FeatureLoader(edu.neu.ccs.pyramid.elasticsearch.FeatureLoader) Terms(org.elasticsearch.search.aggregations.bucket.terms.Terms) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) FileWriter(java.io.FileWriter) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelIndex(edu.neu.ccs.pyramid.elasticsearch.MultiLabelIndex) ConcurrentHashMultiset(com.google.common.collect.ConcurrentHashMultiset) ESIndex(edu.neu.ccs.pyramid.elasticsearch.ESIndex) NgramTemplate(edu.neu.ccs.pyramid.feature_extraction.NgramTemplate) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) StumpSelector(edu.neu.ccs.pyramid.feature_extraction.StumpSelector) Pattern(java.util.regex.Pattern)
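The selection above keeps, for each label, only the toKeep highest-scoring ngrams by pushing Pair<Ngram, Double> objects into pyramid's BoundedBlockPriorityQueue. As a rough illustration of the same top-k pattern, here is a minimal sketch using a plain java.util.PriorityQueue instead; only the Pair constructor and the getFirst/getSecond accessors come from the examples on this page, while the TopKPairsSketch class and its keepTopK helper are hypothetical.

import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

class TopKPairsSketch {

    // keep only the k highest-scoring pairs: a min-heap ordered by score lets us
    // evict the current minimum whenever the heap grows past k elements
    static List<Pair<String, Double>> keepTopK(List<Pair<String, Double>> candidates, int k) {
        Comparator<Pair<String, Double>> byScore = Comparator.comparing(Pair::getSecond);
        PriorityQueue<Pair<String, Double>> minHeap = new PriorityQueue<>(byScore);
        for (Pair<String, Double> candidate : candidates) {
            minHeap.offer(candidate);
            if (minHeap.size() > k) {
                // drop the lowest-scoring pair seen so far
                minHeap.poll();
            }
        }
        List<Pair<String, Double>> result = new ArrayList<>(minHeap);
        // present the survivors from highest to lowest score
        result.sort(byScore.reversed());
        return result;
    }
}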

Example 2 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.

From the class TRECFormat, the method writeMatrixFile:

private static void writeMatrixFile(ClfDataSet dataSet, File trecFile) {
    File matrixFile = new File(trecFile, TREC_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    int[] labels = dataSet.getLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            int label = labels[i];
            bw.write(label + " ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index(), element.get());
                pairs.add(pair);
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            List<Pair<Integer, Double>> sorted = pairs.stream().sorted(comparator).collect(Collectors.toList());
            for (Pair<Integer, Double> pair : sorted) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : ArrayList(java.util.ArrayList) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
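Each line written above has the form "label index:value index:value ..." with the non-zero entries sorted by feature index. For reference, a minimal sketch of the reverse direction under that same layout assumption; the SparseLineSketch class and its parseLine helper are hypothetical and are not part of TRECFormat.

import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.List;

class SparseLineSketch {

    // parse one "label index:value index:value ..." line back into the label
    // and its list of non-zero (featureIndex, value) pairs
    static Pair<Integer, List<Pair<Integer, Double>>> parseLine(String line) {
        String[] tokens = line.trim().split("\\s+");
        int label = Integer.parseInt(tokens[0]);
        List<Pair<Integer, Double>> nonZeros = new ArrayList<>();
        for (int t = 1; t < tokens.length; t++) {
            String[] indexAndValue = tokens[t].split(":");
            nonZeros.add(new Pair<>(Integer.parseInt(indexAndValue[0]),
                    Double.parseDouble(indexAndValue[1])));
        }
        return new Pair<>(label, nonZeros);
    }
}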

Example 3 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.

From the class LibSvmFormat, the method save:

public static void save(RegDataSet dataSet, String libSvmFile) {
    File matrixFile = new File(libSvmFile);
    int numDataPoints = dataSet.getNumDataPoints();
    double[] labels = dataSet.getLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            double label = labels[i];
            bw.write(label + " ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index() + 1, element.get());
                pairs.add(pair);
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            List<Pair<Integer, Double>> sorted = pairs.stream().sorted(comparator).collect(Collectors.toList());
            for (Pair<Integer, Double> pair : sorted) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
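Note that, unlike writeMatrixFile above, this writer shifts every feature index by one (element.index() + 1): the libsvm format numbers features starting from 1 rather than 0.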

Example 4 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.

From the class DataSetUtil, the method toMultiClass:

public static Pair<ClfDataSet, Translator<MultiLabel>> toMultiClass(MultiLabelClfDataSet dataSet) {
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    List<MultiLabel> multiLabels = DataSetUtil.gatherMultiLabels(dataSet);
    Translator<MultiLabel> translator = new Translator<>();
    translator.addAll(multiLabels);
    ClfDataSet clfDataSet = ClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(numFeatures)
            .dense(dataSet.isDense())
            .missingValue(dataSet.hasMissingValue())
            .numClasses(translator.size())
            .build();
    for (int i = 0; i < numDataPoints; i++) {
        //only copy non-zero elements
        Vector vector = dataSet.getRow(i);
        for (Vector.Element element : vector.nonZeroes()) {
            int featureIndex = element.index();
            double value = element.get();
            clfDataSet.setFeatureValue(i, featureIndex, value);
        }
        int label = translator.getIndex(dataSet.getMultiLabels()[i]);
        clfDataSet.setLabel(i, label);
    }
    List<String> extLabels = multiLabels.stream().map(MultiLabel::toString).collect(Collectors.toList());
    LabelTranslator labelTranslator = new LabelTranslator(extLabels);
    clfDataSet.setLabelTranslator(labelTranslator);
    clfDataSet.setFeatureList(dataSet.getFeatureList());
    return new Pair<>(clfDataSet, translator);
}
Also used : Translator(edu.neu.ccs.pyramid.util.Translator) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
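A caller unpacks the returned pair with the same getFirst/getSecond accessors used elsewhere on this page. Below is a minimal hypothetical usage sketch, assuming a MultiLabelClfDataSet named multiLabelSet has already been loaded; the toMultiClassUsage method itself is not part of DataSetUtil.

public static void toMultiClassUsage(MultiLabelClfDataSet multiLabelSet) {
    Pair<ClfDataSet, Translator<MultiLabel>> converted = DataSetUtil.toMultiClass(multiLabelSet);
    ClfDataSet singleLabelSet = converted.getFirst();
    Translator<MultiLabel> labelSetTranslator = converted.getSecond();
    // each distinct label combination becomes one class in the converted set
    int classIndex = labelSetTranslator.getIndex(multiLabelSet.getMultiLabels()[0]);
    System.out.println("data point 0 maps to class " + classIndex
            + " out of " + labelSetTranslator.size() + " classes");
    System.out.println("converted set has " + singleLabelSet.getNumDataPoints() + " data points");
}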

Example 5 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.

From the class LinearRegElasticNet, the method main:

public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    String sparsity = config.getString("featureMatrix.sparsity").toLowerCase();
    DataSetType dataSetType = null;
    switch(sparsity) {
        case "dense":
            dataSetType = DataSetType.REG_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.REG_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("featureMatrix.sparsity can be either dense or sparse");
    }
    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainSet"), dataSetType, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(config.getString("input.testSet"), dataSetType, true);
    LinearRegression linearRegression = new LinearRegression(trainSet.getNumFeatures());
    ElasticNetLinearRegOptimizer optimizer = new ElasticNetLinearRegOptimizer(linearRegression, trainSet);
    optimizer.setRegularization(config.getDouble("regularization"));
    optimizer.setL1Ratio(config.getDouble("l1Ratio"));
    System.out.println("before training");
    System.out.println("training set RMSE = " + RMSE.rmse(linearRegression, trainSet));
    System.out.println("test set RMSE = " + RMSE.rmse(linearRegression, testSet));
    System.out.println("start training");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    optimizer.optimize();
    System.out.println("training done");
    System.out.println("time spent on training = " + stopWatch);
    System.out.println("after training");
    System.out.println("training set RMSE = " + RMSE.rmse(linearRegression, trainSet));
    System.out.println("test set RMSE = " + RMSE.rmse(linearRegression, testSet));
    System.out.println("number of non-zeros weights in linear regression (not including bias) = " + linearRegression.getWeights().getWeightsWithoutBias().getNumNonZeroElements());
    List<Pair<Integer, Double>> sorted = new ArrayList<>();
    for (Vector.Element element : linearRegression.getWeights().getWeightsWithoutBias().nonZeroes()) {
        sorted.add(new Pair<>(element.index(), element.get()));
    }
    Comparator<Pair<Integer, Double>> comparatorByIndex = Comparator.comparing(pair -> pair.getFirst());
    sorted = sorted.stream().sorted(comparatorByIndex).collect(Collectors.toList());
    StringBuilder sb1 = new StringBuilder();
    for (Pair<Integer, Double> pair : sorted) {
        int index = pair.getFirst();
        sb1.append(index).append("(").append(trainSet.getFeatureList().get(index).getName()).append(")").append(":").append(pair.getSecond()).append("\n");
    }
    FileUtils.writeStringToFile(new File(output, "features_sorted_by_indices.txt"), sb1.toString());
    System.out.println("all selected features (sorted by indices) are saved to " + new File(output, "features_sorted_by_indices.txt").getAbsolutePath());
    Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(pair -> Math.abs(pair.getSecond()));
    sorted = sorted.stream().sorted(comparator.reversed()).collect(Collectors.toList());
    StringBuilder sb = new StringBuilder();
    for (Pair<Integer, Double> pair : sorted) {
        int index = pair.getFirst();
        sb.append(index).append("(").append(trainSet.getFeatureList().get(index).getName()).append(")").append(":").append(pair.getSecond()).append("\n");
    }
    FileUtils.writeStringToFile(new File(output, "features_sorted_by_weights.txt"), sb.toString());
    System.out.println("all selected features (sorted by absolute weights) are saved to " + new File(output, "features_sorted_by_weights.txt").getAbsolutePath());
    File reportFile = new File(output, "test_predictions.txt");
    report(linearRegression, testSet, reportFile);
    System.out.println("predictions on the test set are written to " + reportFile.getAbsolutePath());
}
Also used : DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) Config(edu.neu.ccs.pyramid.configuration.Config) ArrayList(java.util.ArrayList) StopWatch(org.apache.commons.lang3.time.StopWatch) ElasticNetLinearRegOptimizer(edu.neu.ccs.pyramid.regression.linear_regression.ElasticNetLinearRegOptimizer) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File) LinearRegression(edu.neu.ccs.pyramid.regression.linear_regression.LinearRegression) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
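The two passes above write one report ordered by feature index and another ordered by absolute weight. If a single combined ordering were wanted, the Pair comparators could be chained; a minimal sketch with hypothetical sample weights (WeightOrderingSketch is not part of the project):

import edu.neu.ccs.pyramid.util.Pair;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

class WeightOrderingSketch {
    public static void main(String[] args) {
        List<Pair<Integer, Double>> weights = Arrays.asList(
                new Pair<>(3, -0.8), new Pair<>(0, 0.8), new Pair<>(5, 0.1));
        // largest absolute weight first; ties broken by ascending feature index
        Comparator<Pair<Integer, Double>> byAbsWeightDesc =
                Comparator.<Pair<Integer, Double>, Double>comparing(p -> Math.abs(p.getSecond()))
                        .reversed()
                        .thenComparing(Pair::getFirst);
        weights.stream().sorted(byAbsWeightDesc)
                .forEach(p -> System.out.println(p.getFirst() + ":" + p.getSecond()));
    }
}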

Aggregations

Pair (edu.neu.ccs.pyramid.util.Pair) 22
Vector (org.apache.mahout.math.Vector) 16
ArrayList (java.util.ArrayList) 10
File (java.io.File) 7
Collectors (java.util.stream.Collectors) 5
Config (edu.neu.ccs.pyramid.configuration.Config) 4
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel) 4
BufferedWriter (java.io.BufferedWriter) 4
FileWriter (java.io.FileWriter) 4
IOException (java.io.IOException) 4
java.util (java.util) 4
ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset) 2
Multiset (com.google.common.collect.Multiset) 2
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset) 2
DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType) 2
RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet) 2
ArgSort (edu.neu.ccs.pyramid.util.ArgSort) 2
Serialization (edu.neu.ccs.pyramid.util.Serialization) 2
Paths (java.nio.file.Paths) 2
Comparator (java.util.Comparator) 2