
Example 1 with BMTrainer

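Use of edu.neu.ccs.pyramid.clustering.bm.BMTrainer in the project pyramid by cheng-li.

From the class ClusterLabels, method fitModel. It reads a label file into a sparse binary data set, trains a Bernoulli mixture model with EM, and writes the trained model to disk.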

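import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.apache.commons.io.FileUtils;
import edu.neu.ccs.pyramid.clustering.bm.BM;
import edu.neu.ccs.pyramid.clustering.bm.BMTrainer;
// Config, DataSet, DataSetBuilder, Density and Serialization also come from pyramid.

private static void fitModel(Config config) throws Exception {
    // Label names, one per line; their count fixes the number of labels (features).
    List<String> labelNames = FileUtils.readLines(new File(config.getString("input.labelNames")), StandardCharsets.UTF_8);
    // One data point per line, holding space-separated label indices; an empty line means no labels.
    List<String> labels = FileUtils.readLines(new File(config.getString("input.labels")), StandardCharsets.UTF_8);
    int numLabels = labelNames.size();
    int numData = labels.size();
    // Sparse binary data set: feature l of point i is 1 if label l is present.
    DataSet dataSet = DataSetBuilder.getBuilder().density(Density.SPARSE_SEQUENTIAL).numDataPoints(numData).numFeatures(numLabels).build();
    for (int i = 0; i < numData; i++) {
        String line = labels.get(i).trim();
        if (!line.isEmpty()) {
            // Split on runs of whitespace so extra spaces do not produce empty tokens.
            for (String s : line.split("\\s+")) {
                int l = Integer.parseInt(s);
                dataSet.setFeatureValue(i, l, 1);
            }
        }
    }
    System.out.println("data loaded");
    int numClusters = config.getInt("numClusters");
    int numIterations = config.getInt("numIterations");
    System.out.println("Start training Bernoulli mixture with EM");
    BMTrainer trainer = new BMTrainer(dataSet, numClusters, 0);
    for (int iter = 1; iter <= numIterations; iter++) {
        System.out.println("iteration = " + iter);
        trainer.eStep(); // E-step: update each point's cluster responsibilities
        trainer.mStep(); // M-step: re-estimate mixing weights and Bernoulli parameters
        // Optional: report the objective every 5 iterations.
        // if (iter % 5 == 0) {
        //     System.out.println("obj = " + trainer.getObjective());
        // }
    }
    BM bm = trainer.getBm();
    bm.setNames(labelNames);
    String output = config.getString("output.dir");
    new File(output).mkdirs();
    // Save the model in serialized form and as human-readable parameters.
    Serialization.serialize(bm, new File(output, "model"));
    FileUtils.writeStringToFile(new File(output, "model_parameters.txt"), bm.toString(), StandardCharsets.UTF_8);
}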
Also used: BMTrainer (edu.neu.ccs.pyramid.clustering.bm.BMTrainer), BM (edu.neu.ccs.pyramid.clustering.bm.BM), File (java.io.File)
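The loading loop in fitModel turns each line of the labels file into one row of a sparse binary matrix: column l is set to 1 for every label index l on that line. The sketch below reproduces only that step in plain Java, without pyramid's DataSet, so the input format can be checked in isolation. It assumes zero-based, whitespace-separated label indices with one data point per line, as the loop implies; the class name LabelMatrixSketch and the range check are illustrative additions, not part of pyramid.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

public class LabelMatrixSketch {

    // Parses one line per data point; each line holds whitespace-separated
    // label indices. Returns, per data point, the sorted column indices that
    // would be set to 1 in a sparse binary matrix with numLabels columns.
    public static List<TreeSet<Integer>> parse(List<String> lines, int numLabels) {
        List<TreeSet<Integer>> rows = new ArrayList<>();
        for (String line : lines) {
            TreeSet<Integer> row = new TreeSet<>();
            String trimmed = line.trim();
            if (!trimmed.isEmpty()) {
                for (String token : trimmed.split("\\s+")) {
                    int label = Integer.parseInt(token);
                    // Hypothetical guard: reject indices outside [0, numLabels).
                    if (label < 0 || label >= numLabels) {
                        throw new IllegalArgumentException("label index out of range: " + label);
                    }
                    row.add(label);
                }
            }
            // Empty lines are kept as data points with no labels.
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) {
        List<String> lines = List.of("0 2", "", "1  2 ", "3");
        System.out.println(parse(lines, 4));
    }
}
```

Here main prints [[0, 2], [], [1, 2], [3]]: the empty line yields a data point with no labels rather than being skipped, which matches how fitModel sets numData to the number of lines.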

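The training loop alternates trainer.eStep() and trainer.mStep(), the two halves of expectation-maximization for a Bernoulli mixture. As a rough picture of what one iteration computes, here is a generic textbook version of that algorithm on a small dense 0/1 matrix. It is a sketch under stated assumptions, not pyramid's BMTrainer: the data-point seeding, the pseudo-count eps, and the names BernoulliMixtureEM and assign are choices made for this example, and the third BMTrainer constructor argument (0 in fitModel) is not modeled because the source does not say what it controls.

```java
import java.util.Arrays;

// Generic EM for a mixture of multivariate Bernoulli distributions (illustrative only).
public class BernoulliMixtureEM {
    private final double[][] x;  // n x d matrix of 0/1 values
    private final int k;         // number of mixture components
    final double[] mix;          // mixing weights pi_c
    final double[][] theta;      // theta[c][j] = P(x_j = 1 | component c)
    final double[][] resp;       // resp[i][c] = P(component c | x_i)

    public BernoulliMixtureEM(double[][] x, int k) {
        this.x = x;
        this.k = k;
        int n = x.length, d = x[0].length;
        mix = new double[k];
        theta = new double[k][d];
        resp = new double[n][k];
        // Assumed initialization: seed component c from data point c * n / k,
        // pulled toward 0.5 so no probability starts at exactly 0 or 1.
        for (int c = 0; c < k; c++) {
            mix[c] = 1.0 / k;
            double[] seed = x[c * n / k];
            for (int j = 0; j < d; j++) theta[c][j] = 0.25 + 0.5 * seed[j];
        }
    }

    // E-step: resp[i][c] proportional to pi_c * prod_j theta^x * (1 - theta)^(1 - x),
    // normalized in log space to avoid underflow.
    public void eStep() {
        for (int i = 0; i < x.length; i++) {
            double[] logp = new double[k];
            double max = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < k; c++) {
                double s = Math.log(mix[c]);
                for (int j = 0; j < x[i].length; j++) {
                    s += x[i][j] > 0.5 ? Math.log(theta[c][j]) : Math.log(1.0 - theta[c][j]);
                }
                logp[c] = s;
                max = Math.max(max, s);
            }
            double z = 0.0;
            for (int c = 0; c < k; c++) {
                resp[i][c] = Math.exp(logp[c] - max);
                z += resp[i][c];
            }
            for (int c = 0; c < k; c++) resp[i][c] /= z;
        }
    }

    // M-step: pi_c = (1/n) sum_i r_ic; theta_cj = sum_i r_ic x_ij / sum_i r_ic,
    // with an assumed pseudo-count eps keeping theta strictly inside (0, 1).
    public void mStep() {
        int n = x.length, d = x[0].length;
        double eps = 1e-3;
        for (int c = 0; c < k; c++) {
            double nc = 0.0;
            for (int i = 0; i < n; i++) nc += resp[i][c];
            mix[c] = nc / n;
            for (int j = 0; j < d; j++) {
                double s = 0.0;
                for (int i = 0; i < n; i++) s += resp[i][c] * x[i][j];
                theta[c][j] = (s + eps) / (nc + 2 * eps);
            }
        }
    }

    // Hard assignment: the component with the highest responsibility for point i.
    public int assign(int i) {
        int best = 0;
        for (int c = 1; c < k; c++) if (resp[i][c] > resp[i][best]) best = c;
        return best;
    }

    public static void main(String[] args) {
        double[][] data = {
            {1, 1, 1, 0, 0, 0}, {1, 1, 1, 0, 0, 0}, {1, 1, 0, 0, 0, 0},
            {0, 0, 0, 1, 1, 1}, {0, 0, 0, 1, 1, 1}, {0, 0, 0, 0, 1, 1}};
        BernoulliMixtureEM em = new BernoulliMixtureEM(data, 2);
        for (int iter = 1; iter <= 20; iter++) {
            em.eStep();
            em.mStep();
        }
        int[] clusters = new int[data.length];
        for (int i = 0; i < data.length; i++) clusters[i] = em.assign(i);
        System.out.println(Arrays.toString(clusters));
    }
}
```

With the two well-separated groups in main, the run prints [0, 0, 0, 1, 1, 1]. Working in log space matters for label data like fitModel's: the per-point likelihood is a product over every label, which underflows quickly as the number of labels grows.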