Example 16 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

the class GeneralF1Predictor method bestWithLengthK.

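private Pair<MultiLabel, Double> bestWithLengthK(double[] deltaVector, int k) {
    // label indices ordered from largest to smallest delta value
    int[] sortedIndices = ArgSort.argSortDescending(deltaVector);
    MultiLabel multiLabel = new MultiLabel();
    double score = 0;
    // take the k highest-scoring labels; k must not exceed deltaVector.length
    for (int i = 0; i < k; i++) {
        int label = sortedIndices[i];
        multiLabel.addLabel(label);
        score += deltaVector[label];
    }
    // the selected label set together with the sum of its delta values
    return new Pair<>(multiLabel, score);
}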
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Pair(edu.neu.ccs.pyramid.util.Pair)
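
The top-k selection in Example 16 can be tried without the pyramid library. The following is a minimal sketch under stated assumptions, not the project's implementation: the nested Pair class and argSortDescending are hypothetical stand-ins for edu.neu.ccs.pyramid.util.Pair and ArgSort.argSortDescending, and a List<Integer> replaces MultiLabel. An explicit bounds check on k is added here; the original method has none.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

class TopKSketch {
    // Hypothetical stand-in for edu.neu.ccs.pyramid.util.Pair.
    static final class Pair<A, B> {
        private final A first;
        private final B second;

        Pair(A first, B second) {
            this.first = first;
            this.second = second;
        }

        A getFirst() {
            return first;
        }

        B getSecond() {
            return second;
        }
    }

    // Hypothetical stand-in for ArgSort.argSortDescending:
    // returns the indices of values, ordered from largest to smallest value.
    static int[] argSortDescending(double[] values) {
        return IntStream.range(0, values.length)
                .boxed()
                .sorted((a, b) -> Double.compare(values[b], values[a]))
                .mapToInt(Integer::intValue)
                .toArray();
    }

    // Picks the k highest-scoring labels and sums their scores.
    static Pair<List<Integer>, Double> bestWithLengthK(double[] deltaVector, int k) {
        if (k < 0 || k > deltaVector.length) {
            throw new IllegalArgumentException("k must be between 0 and " + deltaVector.length);
        }
        int[] sortedIndices = argSortDescending(deltaVector);
        List<Integer> labels = new ArrayList<>();
        double score = 0;
        for (int i = 0; i < k; i++) {
            int label = sortedIndices[i];
            labels.add(label);
            score += deltaVector[label];
        }
        return new Pair<>(labels, score);
    }

    public static void main(String[] args) {
        Pair<List<Integer>, Double> best = bestWithLengthK(new double[] {0.1, 0.7, 0.2, 0.5}, 2);
        System.out.println(best.getFirst() + " " + best.getSecond());
    }
}
```

Returning both values in one Pair lets a caller compare candidate label sets of different lengths by score without recomputing the sum.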

Example 17 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

the class LibSvmFormat method save.

public static void save(ClfDataSet dataSet, String libSvmFile) {
    File matrixFile = new File(libSvmFile);
    int numDataPoints = dataSet.getNumDataPoints();
    int[] labels = dataSet.getLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            int label = labels[i];
            bw.write(label + " ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros; LIBSVM feature indices are 1-based, hence index() + 1
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index() + 1, element.get());
                pairs.add(pair);
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            List<Pair<Integer, Double>> sorted = pairs.stream().sorted(comparator).collect(Collectors.toList());
            for (Pair<Integer, Double> pair : sorted) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
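
The line layout that Example 17 produces can be illustrated on a single row. This is a sketch only: the class and method names are hypothetical, a Map<Integer, Double> stands in for the non-zero entries of a Mahout Vector, and a TreeMap replaces the Pair list plus Comparator because it already iterates keys in ascending order.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

class LibSvmLineSketch {
    // Formats one row as "<label> <index>:<value> ..." with ascending, 1-based indices.
    // nonZeros maps a 0-based feature index to its non-zero value (an assumption of this sketch).
    static String formatRow(int label, Map<Integer, Double> nonZeros) {
        StringBuilder sb = new StringBuilder();
        sb.append(label).append(' ');
        // Copying into a TreeMap sorts the entries by feature index.
        for (Map.Entry<Integer, Double> entry : new TreeMap<>(nonZeros).entrySet()) {
            // Shift to 1-based indices, as the save method above does with index() + 1.
            sb.append(entry.getKey() + 1).append(':').append(entry.getValue()).append(' ');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        Map<Integer, Double> row = new HashMap<>();
        row.put(4, 0.5);
        row.put(0, 2.0);
        System.out.println(formatRow(1, row));
    }
}
```

Like the original, the sketch leaves a trailing space after the last index:value entry.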

Example 18 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

the class LibSvmFormat method save.

public static void save(MultiLabelClfDataSet dataSet, String libSvmFile) {
    File matrixFile = new File(libSvmFile);
    int numDataPoints = dataSet.getNumDataPoints();
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            MultiLabel multiLabel = multiLabels[i];
            List<Integer> labels = multiLabel.getMatchedLabels().stream().sorted().collect(Collectors.toList());
            for (int l = 0; l < labels.size(); l++) {
                bw.write(labels.get(l).toString());
                if (l != labels.size() - 1) {
                    bw.write(",");
                } else {
                    bw.write(" ");
                }
            }
            Vector vector = dataSet.getRow(i);
            // only write non-zeros; LIBSVM feature indices are 1-based, hence index() + 1
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index() + 1, element.get());
                pairs.add(pair);
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            List<Pair<Integer, Double>> sorted = pairs.stream().sorted(comparator).collect(Collectors.toList());
            for (Pair<Integer, Double> pair : sorted) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
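
The comma-separated label prefix written by the index loop in Example 18 could also be built with Collectors.joining. The sketch below is an illustration with hypothetical names, not the project's code, and it differs in one respect: Example 18 writes the separating space only after the last label, so a row with no labels gets no separator, whereas this sketch always appends the space.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.stream.Collectors;

class LabelPrefixSketch {
    // Builds the multi-label prefix of a row, e.g. "3,7 " for the labels {7, 3}.
    static String labelPrefix(Collection<Integer> matchedLabels) {
        return matchedLabels.stream()
                .sorted()
                .map(String::valueOf)
                .collect(Collectors.joining(",")) + " ";
    }

    public static void main(String[] args) {
        System.out.println("[" + labelPrefix(Arrays.asList(7, 3)) + "]");
    }
}
```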

Example 19 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

the class TRECFormat method writeMatrixFile.

private static void writeMatrixFile(MultiLabelClfDataSet dataSet, File trecFile) {
    File matrixFile = new File(trecFile, TREC_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            MultiLabel multiLabel = multiLabels[i];
            List<Integer> labels = multiLabel.getMatchedLabels().stream().sorted().collect(Collectors.toList());
            for (int l = 0; l < labels.size(); l++) {
                bw.write(labels.get(l).toString());
                if (l != labels.size() - 1) {
                    bw.write(",");
                }
            }
            bw.write(" ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index(), element.get());
                pairs.add(pair);
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            List<Pair<Integer, Double>> sorted = pairs.stream().sorted(comparator).collect(Collectors.toList());
            for (Pair<Integer, Double> pair : sorted) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : ArrayList(java.util.ArrayList) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 20 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

the class TRECFormat method writeMatrixFile.

private static void writeMatrixFile(RegDataSet dataSet, File trecFile) {
    File matrixFile = new File(trecFile, TREC_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    double[] labels = dataSet.getLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            double label = labels[i];
            bw.write(label + " ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index(), element.get());
                pairs.add(pair);
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            List<Pair<Integer, Double>> sorted = pairs.stream().sorted(comparator).collect(Collectors.toList());
            for (Pair<Integer, Double> pair : sorted) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : ArrayList(java.util.ArrayList) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)
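
The feature-writing loops in Examples 17 to 20 differ only in whether the index is shifted by one: the LibSvmFormat methods write index() + 1, the TRECFormat methods write index() unchanged. One possible way to share that loop is sketched below. All names here are hypothetical, a Map<Integer, Double> again stands in for the non-zero entries of a Vector, and the IOException is propagated to the caller rather than printed, which is a design choice of this sketch and not the behavior of the methods above.

```java
import java.io.IOException;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.util.Map;
import java.util.TreeMap;

class SparseRowWriterSketch {
    // Writes "index:value " for every entry in ascending index order, then a newline.
    // indexOffset is 1 for the LIBSVM-style output and 0 for the TREC-style output.
    static void writeFeatures(Writer out, Map<Integer, Double> nonZeros, int indexOffset)
            throws IOException {
        for (Map.Entry<Integer, Double> entry : new TreeMap<>(nonZeros).entrySet()) {
            out.write((entry.getKey() + indexOffset) + ":" + entry.getValue() + " ");
        }
        out.write("\n");
    }

    // Convenience wrapper that renders one row into a String.
    static String render(Map<Integer, Double> nonZeros, int indexOffset) {
        StringWriter buffer = new StringWriter();
        try {
            writeFeatures(buffer, nonZeros, indexOffset);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return buffer.toString();
    }

    public static void main(String[] args) {
        Map<Integer, Double> row = new TreeMap<>();
        row.put(2, 1.5);
        row.put(0, 3.0);
        System.out.print(render(row, 0)); // TREC-style indices
        System.out.print(render(row, 1)); // LIBSVM-style indices
    }
}
```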

Aggregations

Pair (edu.neu.ccs.pyramid.util.Pair) 22
Vector (org.apache.mahout.math.Vector) 16
ArrayList (java.util.ArrayList) 10
File (java.io.File) 7
Collectors (java.util.stream.Collectors) 5
Config (edu.neu.ccs.pyramid.configuration.Config) 4
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel) 4
BufferedWriter (java.io.BufferedWriter) 4
FileWriter (java.io.FileWriter) 4
IOException (java.io.IOException) 4
java.util (java.util) 4
ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset) 2
Multiset (com.google.common.collect.Multiset) 2
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset) 2
DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType) 2
RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet) 2
ArgSort (edu.neu.ccs.pyramid.util.ArgSort) 2
Serialization (edu.neu.ccs.pyramid.util.Serialization) 2
Paths (java.nio.file.Paths) 2
Comparator (java.util.Comparator) 2