Search in sources :

Example 6 with BernoulliDistribution

use of edu.neu.ccs.pyramid.util.BernoulliDistribution in project pyramid by cheng-li.

the class CBMTest method test2.

/**
 * Trains a 4-component CBM on a synthetic multi-label data set and prints
 * per-iteration diagnostics.
 *
 * <p>A 1000-point training set and a 100-point test set (2 features, 4 classes
 * each) are filled by {@link #fillWithNoisyLabels}. The model is then
 * initialized and optimized for 3 iterations; after each iteration the
 * objective value, training accuracy and test accuracy are printed, and the
 * final model is printed at the end.
 *
 * @throws Exception if data generation, initialization or optimization fails
 */
private static void test2() throws Exception {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(2).numClasses(4).numDataPoints(1000).build();
    BernoulliDistribution bernoulliDistribution = new BernoulliDistribution(0.5);
    fillWithNoisyLabels(dataSet, bernoulliDistribution);

    // The test set shares the same sampler and generation procedure as the training set.
    MultiLabelClfDataSet testSet = MLClfDataSetBuilder.getBuilder().numFeatures(2).numClasses(4).numDataPoints(100).build();
    fillWithNoisyLabels(testSet, bernoulliDistribution);

    int numComponents = 4;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(dataSet.getNumClasses())
            .setNumFeatures(dataSet.getNumFeatures())
            .setNumComponents(numComponents)
            .setBinaryClassifierType("boost")
            .setMultiClassClassifierType("boost")
            .build();
    cbm.setPredictMode("dynamic");

    CBMOptimizer optimizer = new CBMOptimizer(cbm, dataSet);
    optimizer.setPriorVarianceBinary(10);
    optimizer.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(cbm, dataSet, optimizer);

    for (int i = 0; i < 3; i++) {
        optimizer.iterate();
        System.out.print("i: " + i + "\t");
        System.out.print("objective: " + optimizer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc: " + Accuracy.accuracy(cbm, dataSet) + "\t");
        System.out.println("testAcc: " + Accuracy.accuracy(cbm, testSet));
    }
    System.out.println(cbm.toString());
}

/**
 * Fills every feature of every data point with a sampled bit and attaches one
 * label per feature derived from a noisy copy of that bit.
 *
 * <p>For each (data point, feature) cell a bit is drawn from
 * {@code bernoulliDistribution} and stored as the feature value. With
 * probability 0.1 (decided by {@link Math#random()}) the copy used for
 * labelling is inverted. Feature 0 then contributes label 0 or 1, and every
 * other feature contributes label 2 or 3; the lower label is chosen when the
 * (possibly inverted) bit is 0.
 *
 * @param data                  the data set to populate in place
 * @param bernoulliDistribution the sampler used for the feature bits
 * @throws Exception if sampling or writing to the data set fails
 */
private static void fillWithNoisyLabels(MultiLabelClfDataSet data, BernoulliDistribution bernoulliDistribution) throws Exception {
    for (int n = 0; n < data.getNumDataPoints(); n++) {
        for (int m = 0; m < data.getNumFeatures(); m++) {
            int bit = bernoulliDistribution.sample();
            data.setFeatureValue(n, m, bit);

            // Label noise: invert the bit used for labelling in roughly 10% of the cells.
            int flip = bit;
            if (Math.random() < 0.1) {
                flip = 1 - bit;
            }

            // Feature 0 owns labels {0, 1}; any other feature owns labels {2, 3}.
            int lowerLabel = (m == 0) ? 0 : 2;
            data.addLabel(n, flip == 0 ? lowerLabel : lowerLabel + 1);
        }
    }
}
Also used : BernoulliDistribution(edu.neu.ccs.pyramid.util.BernoulliDistribution) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 7 with BernoulliDistribution

use of edu.neu.ccs.pyramid.util.BernoulliDistribution in project pyramid by cheng-li.

the class ClusterLabels method getCluster.

/**
 * Builds the word-frequency list for one cluster of a Bernoulli mixture.
 *
 * <p>Each dimension's name is paired with the probability {@code getP()} of
 * that dimension's distribution in cluster {@code k}. The pairs are sorted by
 * probability in descending order, entries whose probability is not strictly
 * positive are dropped, and at most the first 20 are kept. Each kept
 * probability is rescaled to {@code p * 200 / sum}, where {@code sum} is the
 * total probability of the kept entries, and truncated to an {@code int}.
 *
 * @param bm the mixture model supplying dimension names and distributions
 * @param k  index of the cluster (first index into {@code bm.getDistributions()})
 * @return a mutable list of at most 20 word frequencies, highest probability
 *         first; empty when no dimension has a positive probability
 * @throws Exception declared for compatibility with existing callers
 */
private static List<WordFrequency> getCluster(BM bm, int k) throws Exception {
    final int maxWords = 20;
    final int totalWeight = 200;

    BernoulliDistribution[][] distributions = bm.getDistributions();
    List<Pair<String, Double>> pairs = new ArrayList<>();
    for (int d = 0; d < bm.getDimension(); d++) {
        Pair<String, Double> pair = new Pair<>(bm.getNames().get(d), distributions[k][d].getP());
        pairs.add(pair);
    }

    // Select the top entries once; both the normalizer and the output are derived from this list.
    Comparator<Pair<String, Double>> byProbability = Comparator.comparing(Pair::getSecond);
    List<Pair<String, Double>> top = pairs.stream()
            .sorted(byProbability.reversed())
            .filter(pair -> pair.getSecond() > 0)
            .limit(maxWords)
            .collect(Collectors.toList());

    // Every kept probability is > 0, so sum is positive whenever the loop below executes.
    double sum = top.stream().mapToDouble(Pair::getSecond).sum();

    List<WordFrequency> frequencies = new ArrayList<>();
    for (Pair<String, Double> pair : top) {
        frequencies.add(new WordFrequency(pair.getFirst(), (int) (pair.getSecond() * totalWeight / sum)));
    }
    return frequencies;
}
Also used : edu.neu.ccs.pyramid.util(edu.neu.ccs.pyramid.util) java.util(java.util) ArgSort(edu.neu.ccs.pyramid.util.ArgSort) CollisionMode(com.kennycason.kumo.CollisionMode) CenterWordStart(com.kennycason.kumo.wordstart.CenterWordStart) Random(java.util.Random) BMTrainer(edu.neu.ccs.pyramid.clustering.bm.BMTrainer) ArrayList(java.util.ArrayList) LinearFontScalar(com.kennycason.kumo.font.scale.LinearFontScalar) RectangleBackground(com.kennycason.kumo.bg.RectangleBackground) WordCloud(com.kennycason.kumo.WordCloud) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) BernoulliDistribution(edu.neu.ccs.pyramid.util.BernoulliDistribution) AngleGenerator(com.kennycason.kumo.image.AngleGenerator) FileUtils(org.apache.commons.io.FileUtils) Collectors(java.util.stream.Collectors) ColorPalette(com.kennycason.kumo.palette.ColorPalette) File(java.io.File) java.awt(java.awt) List(java.util.List) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) WordFrequency(com.kennycason.kumo.WordFrequency) Comparator(java.util.Comparator) BM(edu.neu.ccs.pyramid.clustering.bm.BM) ArrayList(java.util.ArrayList) WordFrequency(com.kennycason.kumo.WordFrequency) Pair(edu.neu.ccs.pyramid.util.Pair)

Aggregations

BernoulliDistribution (edu.neu.ccs.pyramid.util.BernoulliDistribution)7 DenseVector (org.apache.mahout.math.DenseVector)3 Vector (org.apache.mahout.math.Vector)3 ArrayList (java.util.ArrayList)2 CollisionMode (com.kennycason.kumo.CollisionMode)1 WordCloud (com.kennycason.kumo.WordCloud)1 WordFrequency (com.kennycason.kumo.WordFrequency)1 RectangleBackground (com.kennycason.kumo.bg.RectangleBackground)1 LinearFontScalar (com.kennycason.kumo.font.scale.LinearFontScalar)1 AngleGenerator (com.kennycason.kumo.image.AngleGenerator)1 ColorPalette (com.kennycason.kumo.palette.ColorPalette)1 CenterWordStart (com.kennycason.kumo.wordstart.CenterWordStart)1 BM (edu.neu.ccs.pyramid.clustering.bm.BM)1 BMTrainer (edu.neu.ccs.pyramid.clustering.bm.BMTrainer)1 Config (edu.neu.ccs.pyramid.configuration.Config)1 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)1 DataSet (edu.neu.ccs.pyramid.dataset.DataSet)1 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)1 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)1 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)1