Search in sources :

Example 11 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

The class RegTreeInspector, method featureImportance.

/**
 * Computes a per-feature importance score for a regression tree.
 *
 * <p>Every internal (non-leaf) node contributes its recorded reduction to the feature it splits
 * on; the score of a feature is the sum of those contributions over all such nodes.
 *
 * @param tree the regression tree to inspect
 * @return a map from each feature used by at least one split to its summed reduction; features
 *         that never appear in a split are absent from the map
 */
public static Map<Feature, Double> featureImportance(RegressionTree tree) {
    FeatureList features = tree.getFeatureList();
    Map<Feature, Double> totals = new HashMap<>();
    for (Node current : tree.traverse()) {
        // leaves carry no split, so they contribute nothing
        if (current.isLeaf()) {
            continue;
        }
        Feature splitFeature = features.get(current.getFeatureIndex());
        double gain = current.getReduction();
        double soFar = totals.getOrDefault(splitFeature, 0.0);
        totals.put(splitFeature, gain + soFar);
    }
    return totals;
}
Also used : java.util(java.util) Feature(edu.neu.ccs.pyramid.feature.Feature) Vector(org.apache.mahout.math.Vector) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Collectors(java.util.stream.Collectors) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 12 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

The class DataSetUtil, method sampleData.

/**
 * Copies the rows of {@code dataSet} selected by {@code indices} into a new data set, together
 * with copies of the corresponding target distributions.
 *
 * <p>Row {@code i} of both returned components comes from source row {@code indices.get(i)}.
 * The new data set uses the source's density, missing-value flag and feature count, and is given
 * the source's feature list. The id translator is deliberately not copied because repeated
 * indices would produce duplicate external ids.
 *
 * @param dataSet source data set
 * @param targetDistribution per-row target vectors of {@code dataSet}, indexed by source row
 * @param indices source row indices to copy, in output order; may contain duplicates
 * @return pair of (sampled data set, copied target rows)
 */
public static Pair<DataSet, double[][]> sampleData(DataSet dataSet, double[][] targetDistribution, List<Integer> indices) {
    int numSamples = indices.size();
    DataSet sample = DataSetBuilder.getBuilder()
            .dense(dataSet.isDense())
            .missingValue(dataSet.hasMissingValue())
            .numDataPoints(numSamples)
            .numFeatures(dataSet.getNumFeatures())
            .build();
    // Only the outer array is allocated: every row is replaced by a copy below. Pre-allocating
    // inner arrays would be wasted work, and sizing them from targetDistribution[0] would throw
    // on an empty targetDistribution even when there is nothing to sample.
    double[][] sampledTargets = new double[numSamples][];
    for (int i = 0; i < numSamples; i++) {
        int indexInOld = indices.get(i);
        Vector oldVector = dataSet.getRow(indexInOld);
        double[] targets = targetDistribution[indexInOld];
        // copy label
        sampledTargets[i] = Arrays.copyOf(targets, targets.length);
        // copy row feature values; iterating only non-zeros keeps this cheap for sparse vectors
        for (Vector.Element element : oldVector.nonZeroes()) {
            sample.setFeatureValue(i, element.index(), element.get());
        }
    }
    sample.setFeatureList(dataSet.getFeatureList());
    // ignore idTranslator as we may have duplicate extIds
    return new Pair<>(sample, sampledTargets);
}
Also used : Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 13 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

The class ArffFormat, method writeMatrixFile.

/**
 * Writes the feature matrix and labels of a multi-label data set to the file
 * {@code ARFF_MATRIX_FILE_NAME} inside the directory {@code arffFile}.
 *
 * <p>The header declares one numeric attribute per feature (named by its index), followed by one
 * {@code {0,1}} attribute per class. Each data row is written sparsely:
 * <ul>
 *   <li>the non-zero features, in ascending index order, as {@code index:value } tokens;</li>
 *   <li>each matched label, in ascending order, as {@code (label + numFeatures) 1},
 *       comma-separated;</li>
 *   <li>and a closing brace.</li>
 * </ul>
 * A data point with no matched labels is written with only its features and the closing brace.
 *
 * <p>NOTE(review): feature tokens use {@code index:value } while label tokens use
 * {@code index value,}. Standard ARFF sparse rows use comma-separated {@code index value} pairs,
 * so confirm that the consumer of this file accepts the mixed form.
 *
 * <p>An {@link IOException} is printed to stderr and not rethrown.
 *
 * @param dataSet data set to write
 * @param arffFile directory that receives the matrix file
 */
private static void writeMatrixFile(MultiLabelClfDataSet dataSet, File arffFile) {
    File matrixFile = new File(arffFile, ARFF_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        bw.write("@RELATION MATRIX" + "\n");
        for (int i = 0; i < numFeatures; i++) {
            bw.write("@ATTRIBUTE " + i + " NUMERIC" + "\n");
        }
        for (int i = 0; i < dataSet.getNumClasses(); i++) {
            bw.write("@ATTRIBUTE class " + i + " {0,1}" + "\n");
        }
        bw.write("@DATA" + "\n");
        // Non-zero entries are explicitly sorted by feature index before writing.
        Comparator<Pair<Integer, Double>> byIndex = Comparator.comparing(Pair::getFirst);
        for (int i = 0; i < numDataPoints; i++) {
            MultiLabel multiLabel = multiLabels[i];
            List<Integer> labels = multiLabel.getMatchedLabels().stream().sorted().collect(Collectors.toList());
            bw.write("{");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                pairs.add(new Pair<>(element.index(), element.get()));
            }
            pairs.sort(byIndex);
            for (Pair<Integer, Double> pair : pairs) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            // Label attributes are offset by numFeatures. The separator is emitted before every
            // label except the first, so an empty label set is valid and simply closes the row.
            for (int l = 0; l < labels.size(); l++) {
                if (l > 0) {
                    bw.write(",");
                }
                int label = labels.get(l) + numFeatures;
                bw.write(label + " 1");
            }
            bw.write("}" + "\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) IOException(java.io.IOException) BufferedWriter(java.io.BufferedWriter) File(java.io.File) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 14 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

The class ArffFormat, method writeMatrixFile.

/**
 * Writes the feature matrix and regression targets of {@code dataSet} to the file
 * {@code ARFF_MATRIX_FILE_NAME} inside the directory {@code arffFile}.
 *
 * <p>The header declares one numeric attribute per feature (named by its index) plus a numeric
 * {@code class} attribute. Each data row is written sparsely:
 * <ul>
 *   <li>the non-zero features, in ascending index order, as {@code index:value } tokens;</li>
 *   <li>then {@code numFeatures label};</li>
 *   <li>and a closing brace.</li>
 * </ul>
 *
 * <p>An {@link IOException} is printed to stderr and not rethrown.
 *
 * @param dataSet regression data set to write
 * @param arffFile directory that receives the matrix file
 */
private static void writeMatrixFile(RegDataSet dataSet, File arffFile) {
    File output = new File(arffFile, ARFF_MATRIX_FILE_NAME);
    int rowCount = dataSet.getNumDataPoints();
    int featureCount = dataSet.getNumFeatures();
    double[] targets = dataSet.getLabels();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(output))) {
        writer.write("@RELATION MATRIX" + "\n");
        for (int f = 0; f < featureCount; f++) {
            writer.write("@ATTRIBUTE " + f + " NUMERIC" + "\n");
        }
        writer.write("@ATTRIBUTE class NUMERIC" + "\n");
        writer.write("@DATA" + "\n");
        // shared ordering for every row: ascending feature index
        Comparator<Pair<Integer, Double>> byIndex = Comparator.comparing(Pair::getFirst);
        for (int row = 0; row < rowCount; row++) {
            double target = targets[row];
            writer.write("{");
            Vector features = dataSet.getRow(row);
            // collect only the non-zero entries, then order them by index
            List<Pair<Integer, Double>> entries = new ArrayList<>();
            for (Vector.Element element : features.nonZeroes()) {
                entries.add(new Pair<>(element.index(), element.get()));
            }
            entries.sort(byIndex);
            for (Pair<Integer, Double> entry : entries) {
                writer.write(entry.getFirst() + ":" + entry.getSecond() + " ");
            }
            // the target is stored as the attribute right after the last feature
            writer.write(featureCount + " " + target + "}" + "\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Also used : FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) IOException(java.io.IOException) BufferedWriter(java.io.BufferedWriter) File(java.io.File) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 15 with Pair

use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

The class TrustRegionNewtonOptimizer, method trcg.

/**
 * Approximately solves the trust-region subproblem using truncated conjugate gradient (CG).
 *
 * <p>Starting from {@code s = 0}, {@code r = -g} and search direction {@code d = r}, it runs CG
 * iterations using {@code loss.Hv} until one of two conditions holds:
 * <ul>
 *   <li>{@code ||r|| <= 0.1 * ||g||}; or</li>
 *   <li>a step would take {@code ||s||} beyond {@code delta}. In that case the last step is
 *       undone and replaced by the step along {@code d} that lands exactly on
 *       {@code ||s|| = delta}.</li>
 * </ul>
 *
 * <p>NOTE(review): the structure closely mirrors LIBLINEAR's {@code TRON::trcg}; presumably a
 * port of it, so confirm before changing the numerics.
 *
 * <p>NOTE(review): the loop has no iteration cap, and if {@code d.dot(Hd) == 0} the step length
 * becomes infinite or NaN. Termination relies on one of the two break conditions being reached.
 *
 * @param delta trust-region radius
 * @param g input vector; presumably the current gradient (TODO confirm against caller)
 * @return pair (s, r): the computed step and the final residual
 */
private Pair<Vector, Vector> trcg(double delta, Vector g) {
    int numColumns = loss.getNumColumns();
    // unit coefficient for the daxpy call that forms the new search direction
    double one = 1;
    Vector d = new DenseVector(numColumns);
    Vector Hd = new DenseVector(numColumns);
    double rTr, rnewTrnew, cgtol;
    Vector s = new DenseVector(numColumns);
    Vector r = new DenseVector(numColumns);
    Pair<Vector, Vector> result = new Pair<>();
    // initial state: s = 0, r = -g, d = r
    for (int i = 0; i < numColumns; i++) {
        s.set(i, 0);
        r.set(i, -g.get(i));
        d.set(i, r.get(i));
    }
    // relative stopping tolerance on the residual norm
    cgtol = 0.1 * g.norm(2);
    rTr = r.dot(r);
    while (true) {
        if (r.norm(2) <= cgtol) {
            break;
        }
        // presumably Hd = H * d (loss.Hv writes its result into Hd) - confirm
        loss.Hv(d, Hd);
        // CG step length along d
        double alpha = rTr / d.dot(Hd);
        // assumed BLAS daxpy semantics: y += a * x, i.e. s += alpha * d
        daxpy(alpha, d, s);
        if (s.norm(2) > delta) {
            // the step left the trust region: undo it (s -= alpha * d) ...
            alpha = -alpha;
            daxpy(alpha, d, s);
            // ... and instead solve ||s + t*d||^2 = delta^2 for t, a quadratic
            // (d.d) t^2 + 2 (s.d) t + (s.s - delta^2) = 0
            double std = s.dot(d);
            double sts = s.dot(s);
            double dtd = d.dot(d);
            double dsq = delta * delta;
            double rad = Math.sqrt(std * std + dtd * (dsq - sts));
            // positive root, written in the algebraically equivalent form that avoids
            // subtracting nearly equal numbers for each sign of s.d
            if (std >= 0)
                alpha = (dsq - sts) / (std + rad);
            else
                alpha = (rad - std) / dtd;
            // move s onto the boundary and update the residual for that final step
            daxpy(alpha, d, s);
            alpha = -alpha;
            daxpy(alpha, Hd, r);
            break;
        }
        // standard CG updates: r -= alpha * Hd; beta = (r_new . r_new) / (r . r); d = beta * d + r
        alpha = -alpha;
        daxpy(alpha, Hd, r);
        rnewTrnew = r.dot(r);
        double beta = rnewTrnew / rTr;
        // presumably d *= beta (scale helper not visible here)
        scale(beta, d);
        daxpy(one, r, d);
        rTr = rnewTrnew;
    }
    result.setFirst(s);
    result.setSecond(r);
    return result;
}
Also used : DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)

Aggregations

Pair (edu.neu.ccs.pyramid.util.Pair)22 Vector (org.apache.mahout.math.Vector)16 ArrayList (java.util.ArrayList)10 File (java.io.File)7 Collectors (java.util.stream.Collectors)5 Config (edu.neu.ccs.pyramid.configuration.Config)4 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)4 BufferedWriter (java.io.BufferedWriter)4 FileWriter (java.io.FileWriter)4 IOException (java.io.IOException)4 java.util (java.util)4 ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset)2 Multiset (com.google.common.collect.Multiset)2 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)2 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)2 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)2 ArgSort (edu.neu.ccs.pyramid.util.ArgSort)2 Serialization (edu.neu.ccs.pyramid.util.Serialization)2 Paths (java.nio.file.Paths)2 Comparator (java.util.Comparator)2