Use of edu.neu.ccs.pyramid.clustering.bm.BMTrainer in project pyramid by cheng-li.
The method fitModel of the class ClusterLabels.
/**
 * Fits a Bernoulli mixture model to a label matrix with EM and saves the result.
 *
 * <p>Configuration keys read by this method:
 * <ul>
 *   <li>{@code input.labelNames} - text file whose line count determines the number of
 *       features (one column per line); the lines are attached to the model as names.</li>
 *   <li>{@code input.labels} - text file with one data point per line; each line holds
 *       whitespace-separated integer column indices that are set to 1. A blank line
 *       leaves the corresponding row empty.</li>
 *   <li>{@code numClusters} - number of clusters passed to {@code BMTrainer}.</li>
 *   <li>{@code numIterations} - number of E-step/M-step rounds to run.</li>
 *   <li>{@code output.dir} - directory that receives the serialized model ({@code model})
 *       and its string form ({@code model_parameters.txt}).</li>
 * </ul>
 *
 * @param config source of the configuration keys listed above
 * @throws NumberFormatException if a token in the labels file is not a valid integer
 * @throws Exception propagated from the file I/O, training and serialization calls
 */
private static void fitModel(Config config) throws Exception {
    List<String> labelNames = FileUtils.readLines(new File(config.getString("input.labelNames")));
    List<String> labels = FileUtils.readLines(new File(config.getString("input.labels")));
    int numLabels = labelNames.size();
    int numData = labels.size();

    // One row per line of the labels file, one column per label name.
    DataSet dataSet = DataSetBuilder.getBuilder()
            .density(Density.SPARSE_SEQUENTIAL)
            .numDataPoints(numData)
            .numFeatures(numLabels)
            .build();
    for (int i = 0; i < numData; i++) {
        // Trim and split on runs of whitespace so that leading/trailing blanks, tabs or
        // doubled spaces do not produce tokens that make Integer.parseInt fail.
        String line = labels.get(i).trim();
        if (line.isEmpty()) {
            continue;
        }
        for (String token : line.split("\\s+")) {
            int labelIndex = Integer.parseInt(token);
            dataSet.setFeatureValue(i, labelIndex, 1);
        }
    }
    System.out.println("data loaded");

    int numClusters = config.getInt("numClusters");
    System.out.println("Start training Bernoulli mixture with EM");
    // NOTE(review): the third constructor argument (0) is presumably a random seed - confirm.
    BMTrainer trainer = new BMTrainer(dataSet, numClusters, 0);

    // Read the iteration count once instead of querying the config on every pass.
    int numIterations = config.getInt("numIterations");
    for (int iter = 1; iter <= numIterations; iter++) {
        System.out.println("iteration = " + iter);
        trainer.eStep();
        trainer.mStep();
        // Periodic objective reporting (trainer.getObjective()) is left disabled.
    }

    BM bm = trainer.getBm();
    bm.setNames(labelNames);

    String output = config.getString("output.dir");
    new File(output).mkdirs();
    Serialization.serialize(bm, new File(output, "model"));
    FileUtils.writeStringToFile(new File(output, "model_parameters.txt"), bm.toString());
}