Usage of edu.neu.ccs.pyramid.util.BernoulliDistribution in project pyramid by cheng-li:
class CBMTest, method test2.
/**
 * Smoke test for CBM on a small synthetic multi-label problem.
 *
 * <p>Builds a 1000-point training set and a 100-point test set, each with 2 binary features and
 * 4 labels. Both are filled by {@link #fillNoisyBinaryData} with a 10% label-noise rate. A
 * 4-component CBM with boosting classifiers is then trained for 3 optimizer iterations. After
 * each iteration the method prints the objective and the train/test accuracy, and at the end it
 * prints the model.
 *
 * @throws Exception propagated from model construction, initialization or optimization
 */
private static void test2() throws Exception {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(2).numClasses(4).numDataPoints(1000).build();
    BernoulliDistribution bernoulliDistribution = new BernoulliDistribution(0.5);
    fillNoisyBinaryData(dataSet, bernoulliDistribution, 0.1);

    MultiLabelClfDataSet testSet = MLClfDataSetBuilder.getBuilder().numFeatures(2).numClasses(4).numDataPoints(100).build();
    fillNoisyBinaryData(testSet, bernoulliDistribution, 0.1);

    int numComponents = 4;
    CBM cbm = CBM.getBuilder().setNumClasses(dataSet.getNumClasses()).setNumFeatures(dataSet.getNumFeatures()).setNumComponents(numComponents).setBinaryClassifierType("boost").setMultiClassClassifierType("boost").build();
    cbm.setPredictMode("dynamic");

    CBMOptimizer optimizer = new CBMOptimizer(cbm, dataSet);
    optimizer.setPriorVarianceBinary(10);
    optimizer.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(cbm, dataSet, optimizer);

    for (int i = 0; i < 3; i++) {
        optimizer.iterate();
        System.out.print("i: " + i + "\t");
        System.out.print("objective: " + optimizer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc: " + Accuracy.accuracy(cbm, dataSet) + "\t");
        System.out.println("testAcc: " + Accuracy.accuracy(cbm, testSet));
    }
    System.out.println(cbm.toString());
}

/**
 * Fills {@code data} with random binary features and one noisy label per feature.
 *
 * <p>For every data point {@code n} and feature {@code m}:
 * <ul>
 *   <li>A bit is drawn from {@code bernoulli} and stored as the feature value.</li>
 *   <li>With probability {@code noiseRate} (via {@link Math#random()}) the bit is flipped to
 *       form a label bit; otherwise the label bit equals the feature bit.</li>
 *   <li>Feature 0 adds label {@code 0 + labelBit}; every other feature adds label
 *       {@code 2 + labelBit}.</li>
 * </ul>
 *
 * @param data      data set to populate in place (its dimensions are taken as given)
 * @param bernoulli source of the feature bits
 * @param noiseRate probability that a label disagrees with its feature bit
 */
private static void fillNoisyBinaryData(MultiLabelClfDataSet data, BernoulliDistribution bernoulli, double noiseRate) {
    for (int n = 0; n < data.getNumDataPoints(); n++) {
        for (int m = 0; m < data.getNumFeatures(); m++) {
            int bit = bernoulli.sample();
            int labelBit = (Math.random() < noiseRate) ? 1 - bit : bit;
            data.setFeatureValue(n, m, bit);
            // Feature 0 drives labels {0,1}; any other feature drives labels {2,3}.
            int firstLabel = (m == 0) ? 0 : 2;
            data.addLabel(n, firstLabel + labelBit);
        }
    }
}
Usage of edu.neu.ccs.pyramid.util.BernoulliDistribution in project pyramid by cheng-li:
class ClusterLabels, method getCluster.
/**
 * Builds word-cloud frequencies describing cluster {@code k} of {@code bm}.
 *
 * <p>Each dimension name is paired with the success probability of the Bernoulli distribution
 * {@code bm.getDistributions()[k][d]}. Only strictly positive probabilities are kept. The 20
 * largest are rescaled so they would sum to 200, then truncated to {@code int}, so a very small
 * share can become 0.
 *
 * @param bm model supplying per-cluster Bernoulli distributions and dimension names
 * @param k  index of the cluster to summarize (first index into {@code getDistributions()})
 * @return a mutable list of at most 20 frequencies, ordered by decreasing probability
 *         (ties keep dimension order)
 * @throws Exception kept for signature compatibility; this body throws no checked exceptions
 */
private static List<WordFrequency> getCluster(BM bm, int k) throws Exception {
    final int topWords = 20;
    final double totalWeight = 200;

    BernoulliDistribution[][] distributions = bm.getDistributions();
    List<Pair<String, Double>> pairs = new ArrayList<>();
    for (int d = 0; d < bm.getDimension(); d++) {
        pairs.add(new Pair<>(bm.getNames().get(d), distributions[k][d].getP()));
    }

    // Stream.sorted is stable on ordered streams, so filtering before sorting yields exactly the
    // same elements, in the same order, as sorting first and filtering afterwards.
    Comparator<Pair<String, Double>> byProbability = Comparator.comparing(Pair::getSecond);
    List<Pair<String, Double>> top = pairs.stream()
            .filter(pair -> pair.getSecond() > 0)
            .sorted(byProbability.reversed())
            .limit(topWords)
            .collect(Collectors.toList());

    // Summed in the same order as the output, so the normalizer matches a single-pass computation.
    double sum = top.stream().mapToDouble(Pair::getSecond).sum();

    // Non-empty `top` implies sum > 0 (all entries are positive); empty `top` never divides.
    return top.stream()
            .map(pair -> new WordFrequency(pair.getFirst(), (int) (pair.getSecond() * totalWeight / sum)))
            .collect(Collectors.toCollection(ArrayList::new));
}
Aggregations