Example 1 with BM

Use of edu.neu.ccs.pyramid.clustering.bm.BM in the project pyramid by cheng-li.

From the class ClusterLabels, method fitModel.

private static void fitModel(Config config) throws Exception {
    List<String> labelNames = FileUtils.readLines(new File(config.getString("input.labelNames")));
    List<String> labels = FileUtils.readLines(new File(config.getString("input.labels")));
    int numLabels = labelNames.size();
    int numData = labels.size();
    DataSet dataSet = DataSetBuilder.getBuilder()
            .density(Density.SPARSE_SEQUENTIAL)
            .numDataPoints(numData)
            .numFeatures(numLabels)
            .build();
    for (int i = 0; i < numData; i++) {
        String line = labels.get(i);
        if (!line.isEmpty()) {
            String[] split = line.split(" ");
            for (String s : split) {
                int l = Integer.parseInt(s);
                dataSet.setFeatureValue(i, l, 1);
            }
        }
    }
    System.out.println("data loaded");
    int numClusters = config.getInt("numClusters");
    System.out.println("Start training Bernoulli mixture with EM");
    BMTrainer trainer = new BMTrainer(dataSet, numClusters, 0);
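    // Read the iteration count once rather than on every loop test.
    int numIterations = config.getInt("numIterations");
    for (int iter = 1; iter <= numIterations; iter++) {
        System.out.println("iteration = " + iter);
        // One EM round: E-step followed by M-step.
        trainer.eStep();
        trainer.mStep();
        // if (iter % 5 == 0) {
        //     System.out.println("obj = " + trainer.getObjective());
        // }
    }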
    BM bm = trainer.getBm();
    bm.setNames(labelNames);
    String output = config.getString("output.dir");
    new File(output).mkdirs();
    Serialization.serialize(bm, new File(output, "model"));
    FileUtils.writeStringToFile(new File(output, "model_parameters.txt"), bm.toString());
}
Also used : BMTrainer(edu.neu.ccs.pyramid.clustering.bm.BMTrainer) BM(edu.neu.ccs.pyramid.clustering.bm.BM) File(java.io.File)
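
The loading loop in fitModel implies an input format of one data point per line, each line holding space-separated label indices, with an empty line meaning no labels. A minimal self-contained sketch of that parsing step follows. The class LabelLineParser and its parse method are hypothetical names used only for illustration and are not part of pyramid; the sketch also splits on any run of whitespace, which is slightly more lenient than the original split(" ").

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of pyramid: turns label lines into lists of label indices.
class LabelLineParser {

    // Each line holds whitespace-separated label indices; an empty line yields an empty row.
    static List<List<Integer>> parse(List<String> lines) {
        List<List<Integer>> rows = new ArrayList<>();
        for (String line : lines) {
            List<Integer> row = new ArrayList<>();
            String trimmed = line.trim();
            if (!trimmed.isEmpty()) {
                for (String token : trimmed.split("\\s+")) {
                    row.add(Integer.parseInt(token));
                }
            }
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) {
        List<List<Integer>> rows = parse(Arrays.asList("0 2", "", "1"));
        System.out.println(rows); // prints [[0, 2], [], [1]]
    }
}
```

Each index in a row corresponds to one dataSet.setFeatureValue(i, l, 1) call in the original method, so row i marks which labels are present for data point i.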

Example 2 with BM

Use of edu.neu.ccs.pyramid.clustering.bm.BM in the project pyramid by cheng-li.

From the class ClusterLabels, method getCluster.

private static List<WordFrequency> getCluster(BM bm, int k) throws Exception {
    BernoulliDistribution[][] distributions = bm.getDistributions();
    List<Pair<String, Double>> pairs = new ArrayList<>();
    for (int d = 0; d < bm.getDimension(); d++) {
        Pair<String, Double> pair = new Pair<>(bm.getNames().get(d), distributions[k][d].getP());
        pairs.add(pair);
    }
    Comparator<Pair<String, Double>> comparator = Comparator.comparing(Pair::getSecond);
    List<Pair<String, Double>> sorted = pairs.stream().sorted(comparator.reversed()).collect(Collectors.toList());
    // Keep the 20 most probable labels that have non-zero probability.
    List<Pair<String, Double>> top = sorted.stream().filter(pair -> pair.getSecond() > 0).limit(20).collect(Collectors.toList());
    double sum = top.stream().mapToDouble(Pair::getSecond).sum();
    // Rescale the kept probabilities to integer frequencies that sum to at most 200.
    List<WordFrequency> frequencies = new ArrayList<>();
    for (Pair<String, Double> pair : top) {
        frequencies.add(new WordFrequency(pair.getFirst(), (int) (pair.getSecond() * 200 / sum)));
    }
    return frequencies;
}
Also used : edu.neu.ccs.pyramid.util(edu.neu.ccs.pyramid.util) java.util(java.util) ArgSort(edu.neu.ccs.pyramid.util.ArgSort) CollisionMode(com.kennycason.kumo.CollisionMode) CenterWordStart(com.kennycason.kumo.wordstart.CenterWordStart) Random(java.util.Random) BMTrainer(edu.neu.ccs.pyramid.clustering.bm.BMTrainer) ArrayList(java.util.ArrayList) LinearFontScalar(com.kennycason.kumo.font.scale.LinearFontScalar) RectangleBackground(com.kennycason.kumo.bg.RectangleBackground) WordCloud(com.kennycason.kumo.WordCloud) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) BernoulliDistribution(edu.neu.ccs.pyramid.util.BernoulliDistribution) AngleGenerator(com.kennycason.kumo.image.AngleGenerator) FileUtils(org.apache.commons.io.FileUtils) Collectors(java.util.stream.Collectors) ColorPalette(com.kennycason.kumo.palette.ColorPalette) File(java.io.File) java.awt(java.awt) List(java.util.List) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) WordFrequency(com.kennycason.kumo.WordFrequency) Comparator(java.util.Comparator) BM(edu.neu.ccs.pyramid.clustering.bm.BM)
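
The selection and rescaling done in getCluster can be tried without pyramid or kumo on the classpath. The sketch below mirrors that logic with plain java.util types: keep the largest positive probabilities, then scale them to integer weights. TopWeights and topScaled are hypothetical names, and the limit and scale parameters stand in for the hard-coded 20 and 200 of the original.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical stand-in for the getCluster logic, using only the JDK.
class TopWeights {

    // Keeps the `limit` largest positive probabilities and rescales them to
    // integer weights whose total is at most `scale` (truncation loses fractions).
    static LinkedHashMap<String, Integer> topScaled(Map<String, Double> probs, int limit, int scale) {
        List<Map.Entry<String, Double>> top = probs.entrySet().stream()
                .filter(e -> e.getValue() > 0)
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(limit)
                .collect(Collectors.toList());
        double sum = top.stream().mapToDouble(Map.Entry::getValue).sum();
        LinkedHashMap<String, Integer> result = new LinkedHashMap<>();
        for (Map.Entry<String, Double> e : top) {
            result.put(e.getKey(), (int) (e.getValue() * scale / sum));
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Double> probs = new LinkedHashMap<>();
        probs.put("a", 0.5);
        probs.put("b", 0.25);
        probs.put("c", 0.0);
        probs.put("d", 0.125);
        System.out.println(topScaled(probs, 2, 200)); // prints {a=133, b=66}
    }
}
```

Because the cast truncates, the weights can sum to slightly less than the scale (199 in the example), which is harmless for a word cloud where only relative sizes matter.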

Example 3 with BM

Use of edu.neu.ccs.pyramid.clustering.bm.BM in the project pyramid by cheng-li.

From the class ClusterLabels, method plot.

public static void plot(Config config) throws Exception {
    BM bm = (BM) Serialization.deserialize(new File(config.getString("output.dir"), "model"));
    double[] coefficients = bm.getMixtureCoefficients();
    // Component indices ordered from largest to smallest mixture coefficient.
    int[] sortedComponents = ArgSort.argSortDescending(coefficients);
    File clusterFolder = Paths.get(config.getString("output.dir"), "clusters").toFile();
    clusterFolder.mkdirs();
    FileUtils.cleanDirectory(clusterFolder);
    for (int i = 0; i < sortedComponents.length; i++) {
        int k = sortedComponents[i];
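        List<WordFrequency> frequencies = getCluster(bm, k);
        // Skip a component with no positive-probability labels; max() would be empty for it.
        if (frequencies.isEmpty()) {
            continue;
        }
        double max = frequencies.stream().mapToDouble(WordFrequency::getFrequency).max().getAsDouble();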
        double sum = frequencies.stream().mapToDouble(WordFrequency::getFrequency).sum();
        double ratio = sum / max;
        final Dimension dimension = new Dimension(600, 600);
        final WordCloud wordCloud = new WordCloud(dimension, CollisionMode.RECTANGLE);
        wordCloud.setPadding(0);
        wordCloud.setAngleGenerator(new AngleGenerator(0));
        wordCloud.setBackground(new RectangleBackground(dimension));
        wordCloud.setColorPalette(buildRandomColorPalette(20));
        wordCloud.setBackgroundColor(Color.WHITE);
        wordCloud.setFontScalar(new LinearFontScalar(20, (int) (500 / ratio)));
        wordCloud.setWordStartStrategy(new CenterWordStart());
        wordCloud.build(frequencies);
        File out = Paths.get(config.getString("output.dir"), "clusters", "" + i + "_" + coefficients[k] + ".png").toFile();
        wordCloud.writeToFile(out.getAbsolutePath());
    }
}
Also used : WordFrequency(com.kennycason.kumo.WordFrequency) LinearFontScalar(com.kennycason.kumo.font.scale.LinearFontScalar) BM(edu.neu.ccs.pyramid.clustering.bm.BM) RectangleBackground(com.kennycason.kumo.bg.RectangleBackground) AngleGenerator(com.kennycason.kumo.image.AngleGenerator) WordCloud(com.kennycason.kumo.WordCloud) File(java.io.File) CenterWordStart(com.kennycason.kumo.wordstart.CenterWordStart)
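
The plot method visits components in the order returned by ArgSort.argSortDescending, so the output files are numbered from the heaviest cluster down. The sketch below shows what such an argsort does in general, assuming it returns the indices of the input ordered by decreasing value. ArgSortSketch is a hypothetical JDK-only stand-in, not pyramid's implementation.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical stand-in, not pyramid's ArgSort: indices ordered by decreasing value.
class ArgSortSketch {

    static int[] argSortDescending(double[] values) {
        Integer[] idx = new Integer[values.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        // Sort the indices by the value each one points at, largest first.
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> values[i]).reversed());
        return Arrays.stream(idx).mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        int[] order = argSortDescending(new double[] {0.2, 0.5, 0.3});
        System.out.println(Arrays.toString(order)); // prints [1, 2, 0]
    }
}
```

With mixture coefficients {0.2, 0.5, 0.3}, component 1 would be plotted first, then component 2, then component 0.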

Aggregations

BM (edu.neu.ccs.pyramid.clustering.bm.BM): 3
File (java.io.File): 3
WordCloud (com.kennycason.kumo.WordCloud): 2
WordFrequency (com.kennycason.kumo.WordFrequency): 2
RectangleBackground (com.kennycason.kumo.bg.RectangleBackground): 2
LinearFontScalar (com.kennycason.kumo.font.scale.LinearFontScalar): 2
AngleGenerator (com.kennycason.kumo.image.AngleGenerator): 2
CenterWordStart (com.kennycason.kumo.wordstart.CenterWordStart): 2
BMTrainer (edu.neu.ccs.pyramid.clustering.bm.BMTrainer): 2
CollisionMode (com.kennycason.kumo.CollisionMode): 1
ColorPalette (com.kennycason.kumo.palette.ColorPalette): 1
Config (edu.neu.ccs.pyramid.configuration.Config): 1
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 1
edu.neu.ccs.pyramid.util (edu.neu.ccs.pyramid.util): 1
ArgSort (edu.neu.ccs.pyramid.util.ArgSort): 1
BernoulliDistribution (edu.neu.ccs.pyramid.util.BernoulliDistribution): 1
Pair (edu.neu.ccs.pyramid.util.Pair): 1
Serialization (edu.neu.ccs.pyramid.util.Serialization): 1
java.awt (java.awt): 1
Paths (java.nio.file.Paths): 1