Use of edu.neu.ccs.pyramid.clustering.bm.BM in the project pyramid by cheng-li.
The method fitModel of the class ClusterLabels.
/**
 * Trains a Bernoulli mixture model on a label matrix with EM and saves the result.
 *
 * <p>Configuration keys read:
 * <ul>
 *   <li>{@code input.labelNames} - file whose lines are the label names; the number of
 *       lines is used as the number of features.</li>
 *   <li>{@code input.labels} - file with one line per data point holding the
 *       whitespace-separated integer indices of that point's labels; a blank line
 *       means the data point has no labels.</li>
 *   <li>{@code numClusters} - number of mixture components.</li>
 *   <li>{@code numIterations} - number of E-step/M-step rounds to run.</li>
 *   <li>{@code output.dir} - directory that receives the serialized model
 *       ({@code model}) and its textual dump ({@code model_parameters.txt}).</li>
 * </ul>
 *
 * @param config source of the keys listed above
 * @throws NumberFormatException if a label token is not a valid integer
 * @throws Exception if reading the inputs, training, or writing the outputs fails
 */
private static void fitModel(Config config) throws Exception {
    List<String> labelNames = FileUtils.readLines(new File(config.getString("input.labelNames")));
    List<String> labels = FileUtils.readLines(new File(config.getString("input.labels")));
    int numLabels = labelNames.size();
    int numData = labels.size();
    DataSet dataSet = DataSetBuilder.getBuilder()
            .density(Density.SPARSE_SEQUENTIAL)
            .numDataPoints(numData)
            .numFeatures(numLabels)
            .build();
    for (int i = 0; i < numData; i++) {
        // Trim and split on runs of whitespace so that leading/trailing blanks,
        // doubled spaces, or tabs do not produce empty tokens that fail to parse.
        String line = labels.get(i).trim();
        if (line.isEmpty()) {
            continue;
        }
        for (String token : line.split("\\s+")) {
            int l = Integer.parseInt(token);
            dataSet.setFeatureValue(i, l, 1);
        }
    }
    System.out.println("data loaded");

    int numClusters = config.getInt("numClusters");
    System.out.println("Start training Bernoulli mixture with EM");
    // NOTE(review): the meaning of the third constructor argument (0) is not visible
    // here - presumably a random seed; confirm against BMTrainer.
    BMTrainer trainer = new BMTrainer(dataSet, numClusters, 0);
    // Read once: the iteration count does not change while training.
    int numIterations = config.getInt("numIterations");
    for (int iter = 1; iter <= numIterations; iter++) {
        System.out.println("iteration = " + iter);
        trainer.eStep();
        trainer.mStep();
    }

    BM bm = trainer.getBm();
    bm.setNames(labelNames);
    String output = config.getString("output.dir");
    new File(output).mkdirs();
    Serialization.serialize(bm, new File(output, "model"));
    FileUtils.writeStringToFile(new File(output, "model_parameters.txt"), bm.toString());
}
Use of edu.neu.ccs.pyramid.clustering.bm.BM in the project pyramid by cheng-li.
The method getCluster of the class ClusterLabels.
/**
 * Builds the word-cloud entries for one mixture component.
 *
 * <p>Pairs every feature name with its Bernoulli probability in component {@code k},
 * orders the pairs by descending probability, and keeps at most the first 20 whose
 * probability is strictly positive. Each kept probability is rescaled to
 * {@code (int) (p * 200 / sum)}, where {@code sum} is the total probability of the
 * kept entries.
 *
 * @param bm the mixture model supplying names and per-component distributions
 * @param k  index of the component to describe
 * @return the rescaled entries in descending-probability order; empty when no feature
 *         has a positive probability in component {@code k}
 */
private static List<WordFrequency> getCluster(BM bm, int k) throws Exception {
    BernoulliDistribution[][] table = bm.getDistributions();

    // One (name, probability) entry per feature of component k.
    List<Pair<String, Double>> ranked = new ArrayList<>();
    for (int feature = 0; feature < bm.getDimension(); feature++) {
        Pair<String, Double> entry = new Pair<>(bm.getNames().get(feature), table[k][feature].getP());
        ranked.add(entry);
    }

    // Highest probability first; the sort is stable, so ties keep feature order.
    Comparator<Pair<String, Double>> byProbability = Comparator.comparing(Pair::getSecond);
    ranked.sort(byProbability.reversed());

    // Keep the first 20 entries with a strictly positive probability.
    List<Pair<String, Double>> top = new ArrayList<>();
    for (Pair<String, Double> entry : ranked) {
        if (top.size() == 20) {
            break;
        }
        if (entry.getSecond() > 0) {
            top.add(entry);
        }
    }

    double total = top.stream().mapToDouble(Pair::getSecond).sum();

    List<WordFrequency> result = new ArrayList<>();
    for (Pair<String, Double> entry : top) {
        int scaled = (int) (entry.getSecond() * 200 / total);
        result.add(new WordFrequency(entry.getFirst(), scaled));
    }
    return result;
}
Use of edu.neu.ccs.pyramid.clustering.bm.BM in the project pyramid by cheng-li.
The method plot of the class ClusterLabels.
/**
 * Renders one word-cloud PNG per mixture component of a previously saved model.
 *
 * <p>Loads the model from {@code <output.dir>/model}, visits the components in
 * descending order of mixture coefficient, and writes each cloud to
 * {@code <output.dir>/clusters/<rank>_<coefficient>.png}. The {@code clusters}
 * directory is created if needed and emptied before any image is written.
 *
 * <p>A component for which {@code getCluster} returns no entries is skipped with a
 * console message, because there is nothing to draw and no maximum frequency from
 * which to derive the font scale.
 *
 * @param config supplies {@code output.dir}
 * @throws Exception if the model cannot be loaded or an image cannot be written
 */
public static void plot(Config config) throws Exception {
    String outputDir = config.getString("output.dir");
    BM bm = (BM) Serialization.deserialize(new File(outputDir, "model"));
    double[] coefficients = bm.getMixtureCoefficients();
    int[] sortedComponents = ArgSort.argSortDescending(coefficients);

    File clusterFolder = Paths.get(outputDir, "clusters").toFile();
    clusterFolder.mkdirs();
    FileUtils.cleanDirectory(clusterFolder);

    for (int i = 0; i < sortedComponents.length; i++) {
        int k = sortedComponents[i];
        List<WordFrequency> frequencies = getCluster(bm, k);
        if (frequencies.isEmpty()) {
            // max() below would have no value to return for an empty cluster.
            System.out.println("skipping cluster " + k + ": no label has positive probability");
            continue;
        }

        double max = frequencies.stream().mapToDouble(WordFrequency::getFrequency).max().getAsDouble();
        double sum = frequencies.stream().mapToDouble(WordFrequency::getFrequency).sum();
        // sum / max is 1 when a single word dominates and grows as the weight spreads
        // out, so the largest font shrinks for clusters with many comparable words.
        double ratio = sum / max;

        Dimension dimension = new Dimension(600, 600);
        WordCloud wordCloud = new WordCloud(dimension, CollisionMode.RECTANGLE);
        wordCloud.setPadding(0);
        wordCloud.setAngleGenerator(new AngleGenerator(0));
        wordCloud.setBackground(new RectangleBackground(dimension));
        wordCloud.setColorPalette(buildRandomColorPalette(20));
        wordCloud.setBackgroundColor(Color.WHITE);
        wordCloud.setFontScalar(new LinearFontScalar(20, (int) (500 / ratio)));
        wordCloud.setWordStartStrategy(new CenterWordStart());
        wordCloud.build(frequencies);

        File out = Paths.get(outputDir, "clusters", "" + i + "_" + coefficients[k] + ".png").toFile();
        wordCloud.writeToFile(out.getAbsolutePath());
    }
}
Aggregations