Search in sources:

Example 6 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

Class BMSelector, method selectAll.

/**
 * Encodes the given multi-labels as a binary indicator data set and runs
 * {@code BMSelector.selectTrainer} on it.
 *
 * <p>Row {@code i} of the indicator data set corresponds to {@code multiLabels[i]}; column
 * {@code label} is set to 1 for every label returned by {@code getMatchedLabels()}.
 *
 * @param numClasses  number of indicator columns (labels)
 * @param multiLabels one entry per data point
 * @param numClusters number of clusters passed to {@code selectTrainer}
 * @return a pair whose first element is the selected trainer's {@code getBm()} and whose second
 *         element is that trainer's {@code gammas} array
 */
public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters) {
    DataSet indicators = DataSetBuilder.getBuilder()
            .numDataPoints(multiLabels.length)
            .numFeatures(numClasses)
            .density(Density.SPARSE_RANDOM)
            .build();
    int row = 0;
    for (MultiLabel current : multiLabels) {
        for (int label : current.getMatchedLabels()) {
            indicators.setFeatureValue(row, label, 1);
        }
        row++;
    }
    // NOTE(review): the literal 10 is forwarded unchanged; its meaning is defined by
    // selectTrainer (presumably a number of restarts/trials) — confirm there.
    BMTrainer chosen = BMSelector.selectTrainer(indicators, numClusters, 10);
    Pair<BM, double[][]> result = new Pair<>();
    result.setFirst(chosen.getBm());
    result.setSecond(chosen.gammas);
    return result;
}
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), DataSet (edu.neu.ccs.pyramid.dataset.DataSet), Pair (edu.neu.ccs.pyramid.util.Pair)

Example 7 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

Class ArffFormat, method writeMatrixFile.

/**
 * Writes the features and class labels of {@code dataSet} to a file named
 * {@code ARFF_MATRIX_FILE_NAME} inside the directory {@code arffFile}.
 *
 * <p>The header declares one {@code NUMERIC} attribute per feature (named by its index) and a
 * nominal {@code class} attribute with values {@code 0..numClasses-1}. Each data point is then
 * written on its own line as {@code {i:v i:v ... numFeatures label}}, listing only the non-zero
 * features, in ascending index order.
 *
 * <p>NOTE(review): standard sparse-ARFF rows are written as {@code {i v, i v, ...}} (space
 * between index and value, comma between entries). Here the feature entries use {@code i:v}
 * separated by spaces, while the trailing class entry uses {@code i v}. The format is left
 * unchanged because a matching reader may exist elsewhere — confirm before feeding this file to
 * third-party ARFF parsers.
 *
 * @param dataSet  the classification data set to export
 * @param arffFile directory in which the matrix file is created
 * @throws java.io.UncheckedIOException if the file cannot be written
 */
private static void writeMatrixFile(ClfDataSet dataSet, File arffFile) {
    File matrixFile = new File(arffFile, ARFF_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    int numClasses = dataSet.getNumClasses();
    int[] labels = dataSet.getLabels();
    Comparator<Pair<Integer, Double>> byIndex = Comparator.comparing(Pair::getFirst);
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        bw.write("@RELATION MATRIX" + "\n");
        for (int i = 0; i < numFeatures; i++) {
            bw.write("@ATTRIBUTE " + i + " NUMERIC" + "\n");
        }
        bw.write("@ATTRIBUTE class {0");
        for (int i = 1; i < numClasses; i++) {
            bw.write("," + i);
        }
        bw.write("}" + "\n");
        bw.write("@DATA" + "\n");
        for (int i = 0; i < numDataPoints; i++) {
            int label = labels[i];
            bw.write("{");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros, sorted by feature index so the output order does not depend
            // on the iteration order of nonZeroes()
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                Pair<Integer, Double> pair = new Pair<>(element.index(), element.get());
                pairs.add(pair);
            }
            pairs.sort(byIndex);
            for (Pair<Integer, Double> pair : pairs) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            // the class value is stored as attribute number numFeatures
            bw.write(numFeatures + " " + label + "}" + "\n");
        }
    } catch (IOException e) {
        // propagate instead of only printing the stack trace, so callers do not carry on with
        // a missing or truncated matrix file
        throw new java.io.UncheckedIOException("Failed to write " + matrixFile, e);
    }
}
Also used: FileWriter (java.io.FileWriter), ArrayList (java.util.ArrayList), IOException (java.io.IOException), BufferedWriter (java.io.BufferedWriter), File (java.io.File), Vector (org.apache.mahout.math.Vector), Pair (edu.neu.ccs.pyramid.util.Pair)

Example 8 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

Class CBMInspector, method covariance.

/**
 * Prints, for a single input, the 20 label pairs whose correlation (by absolute value) under the
 * model's predicted label distribution is largest.
 *
 * <p>The distribution is a mixture. The component weights {@code w} come from
 * {@code getMultiClassClassifier().predictClassProbs(vector)}. Within component {@code k}, label
 * {@code l} has probability {@code p[k][l] = getBinaryClassifiers()[k][l].predictClassProb(vector, 1)}
 * and a diagonal within-component covariance {@code diag(p(1-p))}. The label covariance is
 * {@code sum_k w_k (diag(p_k(1-p_k)) + p_k p_k^T) - m m^T} with {@code m = sum_k w_k p_k}. It is
 * normalized to a correlation by dividing by the product of the square roots of the diagonal.
 *
 * <p>NOTE(review): a label with zero variance yields NaN/infinite correlations. Boxed
 * {@code Double} ordering treats NaN as the largest value, so such pairs would be listed first.
 *
 * @param CBM             the trained model
 * @param vector          the input feature vector
 * @param labelTranslator maps internal label indices to external labels for printing
 */
public static void covariance(CBM CBM, Vector vector, LabelTranslator labelTranslator) {
    int numClusters = CBM.getNumComponents();
    int numClasses = CBM.getNumClasses();
    double[] weights = CBM.getMultiClassClassifier().predictClassProbs(vector);
    double[][] probs = new double[numClusters][numClasses];
    for (int k = 0; k < numClusters; k++) {
        for (int l = 0; l < numClasses; l++) {
            probs[k][l] = CBM.getBinaryClassifiers()[k][l].predictClassProb(vector, 1);
        }
    }

    // mixture mean m = sum_k w_k p_k, as a column vector
    Access2D.Builder<PrimitiveMatrix> meanColumn = factory.getBuilder(numClasses, 1);
    for (int l = 0; l < numClasses; l++) {
        double acc = 0;
        for (int k = 0; k < numClusters; k++) {
            acc += weights[k] * probs[k][l];
        }
        meanColumn.set(l, 0, acc);
    }
    BasicMatrix mean = meanColumn.build();

    // second moment sum_k w_k (Sigma_k + mu_k mu_k^T), accumulated component by component
    BasicMatrix secondMoment = factory.makeZero(numClasses, numClasses);
    for (int k = 0; k < numClusters; k++) {
        Access2D.Builder<PrimitiveMatrix> muColumn = factory.getBuilder(numClasses, 1);
        Access2D.Builder<PrimitiveMatrix> sigmaDiag = factory.getBuilder(numClasses, numClasses);
        for (int l = 0; l < numClasses; l++) {
            double p = probs[k][l];
            muColumn.set(l, 0, p);
            sigmaDiag.set(l, l, p * (1 - p));
        }
        BasicMatrix mu = muColumn.build();
        BasicMatrix weighted = (sigmaDiag.build().add(mu.multiply(mu.transpose()))).multiply(weights[k]);
        secondMoment = secondMoment.add(weighted);
    }
    BasicMatrix cov = secondMoment.subtract(mean.multiply(mean.transpose()));

    // correlations of the strictly-lower-triangular pairs (j < l); row 0 has no such pair
    List<Pair<String, Double>> scored = new ArrayList<>();
    for (int l = 1; l < numClasses; l++) {
        double sdL = Math.sqrt(cov.get(l, l).doubleValue());
        for (int j = 0; j < l; j++) {
            double r = cov.get(l, j).doubleValue() / (sdL * Math.sqrt(cov.get(j, j).doubleValue()));
            String name = "" + labelTranslator.toExtLabel(l) + ", " + labelTranslator.toExtLabel(j);
            Pair<String, Double> entry = new Pair<>(name, r);
            scored.add(entry);
        }
    }
    Comparator<Pair<String, Double>> byMagnitude = Comparator.comparing(pair -> Math.abs(pair.getSecond()));
    scored.sort(byMagnitude.reversed());
    System.out.println(scored.subList(0, Math.min(20, scored.size())));
}
Also used: PrimitiveMatrix (org.ojalgo.matrix.PrimitiveMatrix), BasicMatrix (org.ojalgo.matrix.BasicMatrix), Access2D (org.ojalgo.access.Access2D), Pair (edu.neu.ccs.pyramid.util.Pair)

Example 9 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

Class GeneralF1Predictor, method showSupportPrediction.

/**
 * Collects diagnostics for one instance, given a distribution {@code probs} over the label sets
 * in {@code combinations} ({@code probs[i]} belongs to {@code combinations.get(i)}).
 *
 * <p>The returned {@code Analysis} holds:
 * <ul>
 *   <li>{@code kl}: {@code KLDivergence.kl} between a point mass on {@code truth} and {@code probs};</li>
 *   <li>{@code expectedF1Prediction} / {@code expectedF1Truth}: {@code expectedF1(...)} evaluated
 *       for the prediction and for the truth under {@code probs};</li>
 *   <li>{@code actualF1}: {@code new InstanceAverage(numClasses, truth, prediction).getF1()};</li>
 *   <li>{@code joint}: the combinations with probability above 0.01, highest first, each rendered
 *       as {@code labels:prob} followed by ", " (including after the last entry);</li>
 *   <li>{@code prediction}, {@code truth}: the inputs, unchanged.</li>
 * </ul>
 *
 * <p>NOTE(review): if {@code truth} is not among {@code combinations}, the point mass silently
 * falls back to index 0, so {@code kl} is computed against the wrong combination. An empty
 * {@code combinations} list makes the point-mass assignment throw
 * {@code ArrayIndexOutOfBoundsException}. Confirm that callers always include the truth.
 */
public static Analysis showSupportPrediction(List<MultiLabel> combinations, double[] probs, MultiLabel truth, MultiLabel prediction, int numClasses) {
    int numCombinations = combinations.size();

    // first position whose combination equals truth; 0 if none does (see note above)
    int cursor = 0;
    while (cursor < numCombinations && !combinations.get(cursor).equals(truth)) {
        cursor++;
    }
    int truthIndex = cursor < numCombinations ? cursor : 0;

    double[] trueJoint = new double[numCombinations];
    trueJoint[truthIndex] = 1;
    double kl = KLDivergence.kl(trueJoint, probs);

    List<Pair<MultiLabel, Double>> support = new ArrayList<>();
    for (int i = 0; i < numCombinations; i++) {
        Pair<MultiLabel, Double> entry = new Pair<>(combinations.get(i), probs[i]);
        support.add(entry);
    }
    Comparator<Pair<MultiLabel, Double>> byProbability = Comparator.comparing(candidate -> candidate.getSecond());
    // filtering before the stable sort yields the same sequence as sorting first
    List<Pair<MultiLabel, Double>> sorted = support.stream()
            .filter(candidate -> candidate.getSecond() > 0.01)
            .sorted(byProbability.reversed())
            .collect(Collectors.toList());

    double expectedF1Prediction = expectedF1(combinations, probs, prediction, numClasses);
    double expectedF1Truth = expectedF1(combinations, probs, truth, numClasses);
    double actualF1 = new InstanceAverage(numClasses, truth, prediction).getF1();

    StringBuilder joint = new StringBuilder();
    for (Pair<MultiLabel, Double> candidate : sorted) {
        joint.append(candidate.getFirst()).append(":").append(candidate.getSecond()).append(", ");
    }

    Analysis analysis = new Analysis();
    analysis.expectedF1Prediction = expectedF1Prediction;
    analysis.expectedF1Truth = expectedF1Truth;
    analysis.actualF1 = actualF1;
    analysis.kl = kl;
    analysis.prediction = prediction;
    analysis.truth = truth;
    analysis.joint = joint.toString();
    return analysis;
}
Also used: Arrays (java.util.Arrays), ArgSort (edu.neu.ccs.pyramid.util.ArgSort), Multiset (com.google.common.collect.Multiset), DenseVector (org.apache.mahout.math.DenseVector), DenseMatrix (org.apache.mahout.math.DenseMatrix), Matrix (org.apache.mahout.math.Matrix), Collectors (java.util.stream.Collectors), InstanceAverage (edu.neu.ccs.pyramid.eval.InstanceAverage), ArrayList (java.util.ArrayList), KLDivergence (edu.neu.ccs.pyramid.eval.KLDivergence), List (java.util.List), ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), Vector (org.apache.mahout.math.Vector), Enumerator (edu.neu.ccs.pyramid.multilabel_classification.Enumerator), Comparator (java.util.Comparator), Pair (edu.neu.ccs.pyramid.util.Pair)

Example 10 with Pair

Use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.

Class ClusterLabels, method getCluster.

/**
 * Builds word-cloud entries for cluster {@code k} of {@code bm}.
 *
 * <p>Each dimension {@code d} is scored by {@code bm.getDistributions()[k][d].getP()} and named by
 * {@code bm.getNames().get(d)}. The (up to) 20 highest-scoring dimensions with a strictly
 * positive score are kept. Each one's weight is its share of the kept scores' total, multiplied
 * by 200 and truncated to an {@code int}.
 *
 * @param bm the trained model whose cluster is rendered
 * @param k  index of the cluster
 * @return the selected entries, highest score first
 */
private static List<WordFrequency> getCluster(BM bm, int k) throws Exception {
    final int maxWords = 20;
    final int scale = 200;
    BernoulliDistribution[][] distributions = bm.getDistributions();
    List<Pair<String, Double>> pairs = new ArrayList<>();
    for (int d = 0; d < bm.getDimension(); d++) {
        Pair<String, Double> pair = new Pair<>(bm.getNames().get(d), distributions[k][d].getP());
        pairs.add(pair);
    }
    Comparator<Pair<String, Double>> comparator = Comparator.comparing(Pair::getSecond);
    // select the top entries once and reuse them for both the total and the output
    List<Pair<String, Double>> top = pairs.stream()
            .sorted(comparator.reversed())
            .filter(pair -> pair.getSecond() > 0)
            .limit(maxWords)
            .collect(Collectors.toList());
    double sum = top.stream().mapToDouble(Pair::getSecond).sum();
    List<WordFrequency> frequencies = new ArrayList<>(top.size());
    for (Pair<String, Double> pair : top) {
        frequencies.add(new WordFrequency(pair.getFirst(), (int) (pair.getSecond() * scale / sum)));
    }
    return frequencies;
}
Also used: edu.neu.ccs.pyramid.util (edu.neu.ccs.pyramid.util), java.util (java.util), ArgSort (edu.neu.ccs.pyramid.util.ArgSort), CollisionMode (com.kennycason.kumo.CollisionMode), CenterWordStart (com.kennycason.kumo.wordstart.CenterWordStart), Random (java.util.Random), BMTrainer (edu.neu.ccs.pyramid.clustering.bm.BMTrainer), ArrayList (java.util.ArrayList), LinearFontScalar (com.kennycason.kumo.font.scale.LinearFontScalar), RectangleBackground (com.kennycason.kumo.bg.RectangleBackground), WordCloud (com.kennycason.kumo.WordCloud), Pair (edu.neu.ccs.pyramid.util.Pair), Config (edu.neu.ccs.pyramid.configuration.Config), BernoulliDistribution (edu.neu.ccs.pyramid.util.BernoulliDistribution), AngleGenerator (com.kennycason.kumo.image.AngleGenerator), FileUtils (org.apache.commons.io.FileUtils), Collectors (java.util.stream.Collectors), ColorPalette (com.kennycason.kumo.palette.ColorPalette), File (java.io.File), java.awt (java.awt), List (java.util.List), Serialization (edu.neu.ccs.pyramid.util.Serialization), Paths (java.nio.file.Paths), edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset), WordFrequency (com.kennycason.kumo.WordFrequency), Comparator (java.util.Comparator), BM (edu.neu.ccs.pyramid.clustering.bm.BM)

Aggregations

Pair (edu.neu.ccs.pyramid.util.Pair): 22, Vector (org.apache.mahout.math.Vector): 16, ArrayList (java.util.ArrayList): 10, File (java.io.File): 7, Collectors (java.util.stream.Collectors): 5, Config (edu.neu.ccs.pyramid.configuration.Config): 4, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 4, BufferedWriter (java.io.BufferedWriter): 4, FileWriter (java.io.FileWriter): 4, IOException (java.io.IOException): 4, java.util (java.util): 4, ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset): 2, Multiset (com.google.common.collect.Multiset): 2, edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 2, DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType): 2, RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 2, ArgSort (edu.neu.ccs.pyramid.util.ArgSort): 2, Serialization (edu.neu.ccs.pyramid.util.Serialization): 2, Paths (java.nio.file.Paths): 2, Comparator (java.util.Comparator): 2