use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
the class BMTrainerTest method test5.
/**
 * Manual check: trains a {@code BMTrainer} on a 20-row, 5-feature data set with
 * three distinct row patterns and prints the probability the trained model
 * assigns to one vector of each pattern.
 */
private static void test5() {
    DataSet data = DataSetBuilder.getBuilder()
            .numFeatures(5)
            .numDataPoints(20)
            .dense(true)
            .build();

    // Rows 0-4 set feature 0, rows 5-9 set feature 1, rows 10-19 set
    // features 2 and 3. Feature 4 is never written.
    for (int row = 0; row < 20; row++) {
        if (row < 5) {
            data.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            data.setFeatureValue(row, 1, 1);
        } else {
            data.setFeatureValue(row, 2, 1);
            data.setFeatureValue(row, 3, 1);
        }
    }
    System.out.println("dataset = " + data);

    // NOTE(review): the arguments 3 and 0 are presumably the number of
    // components and a random seed -- confirm against BMTrainer's constructor.
    BMTrainer trainer = new BMTrainer(data, 3, 0);
    System.out.println(trainer.bm);
    // An earlier variant drove training by hand with 100 calls to
    // trainer.iterate() instead of a single train() call.
    trainer.train();
    System.out.println(trainer.bm);

    // One probe per row pattern used above.
    Vector[] probes = {new DenseVector(5), new DenseVector(5), new DenseVector(5)};
    probes[0].set(0, 1);
    probes[1].set(1, 1);
    probes[2].set(2, 1);
    probes[2].set(3, 1);
    for (Vector probe : probes) {
        System.out.println(Math.exp(trainer.bm.logProbability(probe)));
    }
}
use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
the class BMTrainerTest method test2.
/**
 * Manual check on the smallest case used in this class: two rows, one feature,
 * with only row 0 having the feature set. Prints the data set, the model held
 * by the trainer before training, and the model returned by {@code train()}.
 */
private static void test2() {
    DataSet data = DataSetBuilder.getBuilder()
            .numFeatures(1)
            .numDataPoints(2)
            .dense(true)
            .build();
    data.setFeatureValue(0, 0, 1);
    System.out.println("dataset = " + data);

    BMTrainer bmTrainer = new BMTrainer(data, 1, 0);
    System.out.println(bmTrainer.bm);
    // Print the model handed back by train() rather than re-reading the field.
    System.out.println(bmTrainer.train());
}
use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
the class BMTrainerTest method test4.
/**
 * Manual check: trains a {@code BMTrainer} on 10 rows with 2 features, where
 * the first five rows set feature 0 and the last five set feature 1, printing
 * the trainer's model before and after training.
 */
private static void test4() {
    DataSet data = DataSetBuilder.getBuilder()
            .numFeatures(2)
            .numDataPoints(10)
            .dense(true)
            .build();

    for (int row = 0; row < 10; row++) {
        // First half of the rows uses feature 0, second half feature 1.
        int feature = (row < 5) ? 0 : 1;
        data.setFeatureValue(row, feature, 1);
    }
    System.out.println("dataset = " + data);

    BMTrainer bmTrainer = new BMTrainer(data, 2, 0);
    System.out.println(bmTrainer.bm);
    // An earlier variant drove training by hand with 100 calls to
    // bmTrainer.iterate() instead of a single train() call.
    bmTrainer.train();
    System.out.println(bmTrainer.bm);
}
use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
the class KMeansTest method mnist.
/**
 * Demo: clusters a 100-image sample of MNIST with k-means and writes PNG
 * renderings of the cluster centers and their members to disk.
 *
 * <p>Steps: clean the output directory, read features and labels, shuffle both
 * with the same seed, load the first {@code rows} images into a data set,
 * initialize k-means, optionally plot the initial clustering, run five
 * iterations (plotting after each), then print the purity and the recorded
 * objective values.
 *
 * @throws IllegalStateException if the feature and label files differ in line
 *     count, contain fewer than {@code rows} lines, or a feature line has more
 *     values than one 28x28 image
 * @throws Exception on I/O failures or anything propagated by the plotting and
 *     clustering helpers
 */
private static void mnist() throws Exception {
    final String outputRoot = "/Users/chengli/tmp/kmeans_demo";
    final String mnistDir = "/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2/mnist";
    final int rows = 100;
    final int numFeatures = 28 * 28;
    final int numComponents = 20;

    FileUtils.cleanDirectory(new File(outputRoot));
    List<String> lines = FileUtils.readLines(new File(mnistDir + "/mnist_features.txt"));
    List<Integer> inputLabels = FileUtils.readLines(new File(mnistDir + "/mnist_labels.txt")).stream()
            .mapToInt(line -> (int) Double.parseDouble(line))
            .boxed()
            .collect(Collectors.toList());

    // The two lists are shuffled independently with identically seeded
    // generators. That yields the same permutation only when both lists have
    // the same length, so a mismatch would silently pair images with the
    // wrong labels.
    if (lines.size() != inputLabels.size()) {
        throw new IllegalStateException("feature file has " + lines.size()
                + " lines but label file has " + inputLabels.size());
    }
    if (lines.size() < rows) {
        throw new IllegalStateException("need at least " + rows
                + " data points but found " + lines.size());
    }
    Collections.shuffle(lines, new Random(0));
    Collections.shuffle(inputLabels, new Random(0));

    int[] labels = inputLabels.stream().limit(rows).mapToInt(a -> a).toArray();
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(numFeatures).build();
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(" ");
        if (split.length > numFeatures) {
            throw new IllegalStateException("line " + i + " (after shuffling) has " + split.length
                    + " values, expected at most " + numFeatures);
        }
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }

    KMeans kMeans = new KMeans(numComponents, dataSet);
    // Alternative initialization: kMeans.randomInitialize();
    kMeans.kmeansPlusPlusInitialize(100);

    List<Double> objectives = new ArrayList<>();
    // Toggle for dumping the clustering produced by initialization alone.
    boolean showInitialize = true;
    if (showInitialize) {
        plotClustersInto(kMeans, dataSet, numComponents, outputRoot + "/clusters/initial");
    }
    objectives.add(kMeans.objective());

    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotClustersInto(kMeans, dataSet, numComponents, outputRoot + "/clusters/iter_" + iter);
    }

    int[] assignment = kMeans.getAssignments();
    // NOTE(review): 10 is presumably the number of digit classes -- confirm
    // against purity()'s signature.
    System.out.println("purity=" + purity(assignment, labels, 10, numComponents));
    System.out.println("training objective changes: " + objectives);
}

/**
 * Plots the center of every cluster and every data point currently assigned to
 * it as 28x28 images under {@code dir}, using the layout
 * {@code dir/cluster_<k+1>/center.png} and {@code dir/cluster_<k+1>/pic_<i+1>.png}.
 */
private static void plotClustersInto(KMeans kMeans, DataSet dataSet, int numComponents, String dir) throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String clusterDir = dir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], 28, 28, clusterDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), 28, 28, clusterDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
the class KMeansTest method extractFashionImages.
/**
 * Renders a fixed sample of 100 images from the fashion feature file as PNGs.
 *
 * <p>The file's lines are shuffled with seed 0, the first 100 are parsed as
 * comma-separated pixel values into a 100 x 784 data set, and each row is then
 * plotted as a 28x28 image named {@code pic_<row+1>.png}.
 *
 * @throws Exception on I/O failure or anything propagated by {@code plot}
 */
private static void extractFashionImages() throws Exception {
    final int side = 28;
    final int sampleCount = 100;

    File featureFile = new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt");
    List<String> records = FileUtils.readLines(featureFile);
    Collections.shuffle(records, new Random(0));

    DataSet images = DataSetBuilder.getBuilder()
            .numDataPoints(sampleCount)
            .numFeatures(side * side)
            .build();

    // Phase 1: parse every sampled record before anything is written to disk.
    for (int row = 0; row < sampleCount; row++) {
        String[] pixels = records.get(row).split(",");
        int col = 0;
        for (String pixel : pixels) {
            images.setFeatureValue(row, col, Double.parseDouble(pixel));
            col++;
        }
    }

    // Phase 2: write one PNG per row; file names are 1-based.
    for (int row = 0; row < sampleCount; row++) {
        String target = "/Users/chengli/tmp/fashion/pic_" + (row + 1) + ".png";
        plot(images.getRow(row), side, side, target);
    }
}
Aggregations