Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
Example: method fashion in class KMeansTest.
/**
 * Demo: clusters a shuffled subset of the Fashion feature file with k-means and plots every cluster.
 *
 * <p>Reads comma-separated 28x28 feature rows, shuffles them with a fixed seed (0), and keeps at
 * most the first 100. It builds a {@link DataSet} from them, initializes 10 clusters with
 * {@code kmeansPlusPlusInitialize(100)} and runs 5 iterations.
 *
 * <p>After initialization and after each iteration, every cluster center and every member image
 * are plotted as PNGs under {@code /Users/chengli/tmp/kmeans_demo/clusters}. The objective history
 * is printed after each iteration.
 *
 * @throws Exception if reading the input, preparing the output directory or plotting fails
 */
private static void fashion() throws Exception {
    File outputDir = new File("/Users/chengli/tmp/kmeans_demo");
    // cleanDirectory rejects a missing directory; make sure it exists first.
    FileUtils.forceMkdir(outputDir);
    FileUtils.cleanDirectory(outputDir);
    // Explicit charset: the one-argument readLines(File) uses the platform default and is deprecated.
    List<String> lines = FileUtils.readLines(
            new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt"),
            java.nio.charset.StandardCharsets.UTF_8);
    // Fixed seed so every run clusters the same subset.
    Collections.shuffle(lines, new Random(0));
    // Never request more rows than the file provides (lines.get(i) would throw otherwise).
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(",");
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    int numComponents = 10;
    KMeans kMeans = new KMeans(numComponents, dataSet);
    // Alternative initialization: kMeans.randomInitialize();
    // NOTE(review): the meaning of the argument 100 is defined by KMeans, not visible here.
    kMeans.kmeansPlusPlusInitialize(100);
    List<Double> objectives = new ArrayList<>();
    // Toggle for plotting the clusters produced by the initialization step.
    boolean showInitialize = true;
    if (showInitialize) {
        plotFashionClusters(kMeans, dataSet, numComponents,
                "/Users/chengli/tmp/kmeans_demo/clusters/initial");
    }
    objectives.add(kMeans.objective());
    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotFashionClusters(kMeans, dataSet, numComponents,
                "/Users/chengli/tmp/kmeans_demo/clusters/iter_" + iter);
        System.out.println("training objective changes: " + objectives);
    }
}

/**
 * Plots each cluster's center and all rows currently assigned to it as 28x28 images.
 *
 * <p>For 0-based cluster {@code k}, the center is written to
 * {@code dir/cluster_<k+1>/center.png}. Each member row {@code i} is written to
 * {@code dir/cluster_<k+1>/pic_<i+1>.png}.
 *
 * @param kMeans        model whose current centers and assignments are plotted
 * @param dataSet       data whose rows are plotted
 * @param numComponents number of clusters to plot
 * @param dir           output directory prefix, without a trailing slash
 * @throws Exception if plotting fails
 */
private static void plotFashionClusters(KMeans kMeans, DataSet dataSet, int numComponents, String dir)
        throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String clusterDir = dir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], 28, 28, clusterDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), 28, 28, clusterDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
Example: method mnistSubgroup in class KMeansTest.
/**
 * Demo: clusters the MNIST images carrying one label with k-means and plots every cluster.
 *
 * <p>Reads the label file, where each line is parsed as a double and truncated to int. It also
 * reads the whitespace-separated feature file and keeps, in file order, the feature rows whose
 * label equals {@code label}. At most the first 100 of those rows form a {@link DataSet}.
 *
 * <p>It initializes 10 clusters with {@code kmeansPlusPlusInitialize(100)} and runs 5 iterations.
 * Cluster centers and member images are plotted as PNGs under
 * {@code /Users/chengli/tmp/kmeans_demo/clusters} after initialization and after each iteration.
 * The objective history is printed after each iteration.
 *
 * @param label the label value whose images are clustered
 * @throws Exception if reading the input, preparing the output directory or plotting fails
 */
private static void mnistSubgroup(int label) throws Exception {
    File outputDir = new File("/Users/chengli/tmp/kmeans_demo");
    // cleanDirectory rejects a missing directory; make sure it exists first.
    FileUtils.forceMkdir(outputDir);
    FileUtils.cleanDirectory(outputDir);
    String dataDir = "/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2/";
    // Explicit charset: the one-argument readLines(File) uses the platform default and is deprecated.
    int[] labels = FileUtils.readLines(new File(dataDir + "mnist_labels.txt"),
                    java.nio.charset.StandardCharsets.UTF_8)
            .stream()
            .mapToInt(l -> (int) Double.parseDouble(l))
            .toArray();
    List<String> features = FileUtils.readLines(new File(dataDir + "mnist_features.txt"),
            java.nio.charset.StandardCharsets.UTF_8);
    List<String> lines = IntStream.range(0, features.size())
            .filter(i -> labels[i] == label)
            .mapToObj(features::get)
            .collect(Collectors.toList());
    // A label may have fewer than 100 images; never index past the filtered list.
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        // Split on any whitespace run: split(" ") yields empty tokens (and a
        // NumberFormatException) on doubled or leading spaces.
        String[] split = lines.get(i).trim().split("\\s+");
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    int numComponents = 10;
    KMeans kMeans = new KMeans(numComponents, dataSet);
    // Alternative initialization: kMeans.randomInitialize();
    // NOTE(review): the meaning of the argument 100 is defined by KMeans, not visible here.
    kMeans.kmeansPlusPlusInitialize(100);
    List<Double> objectives = new ArrayList<>();
    // Toggle for plotting the clusters produced by the initialization step.
    boolean showInitialize = true;
    if (showInitialize) {
        plotMnistClusters(kMeans, dataSet, numComponents,
                "/Users/chengli/tmp/kmeans_demo/clusters/initial");
    }
    objectives.add(kMeans.objective());
    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotMnistClusters(kMeans, dataSet, numComponents,
                "/Users/chengli/tmp/kmeans_demo/clusters/iter_" + iter);
        System.out.println("training objective changes: " + objectives);
    }
}

/**
 * Plots each cluster's center and all rows currently assigned to it as 28x28 images.
 *
 * <p>For 0-based cluster {@code k}, the center is written to
 * {@code dir/cluster_<k+1>/center.png}. Each member row {@code i} is written to
 * {@code dir/cluster_<k+1>/pic_<i+1>.png}.
 *
 * @param kMeans        model whose current centers and assignments are plotted
 * @param dataSet       data whose rows are plotted
 * @param numComponents number of clusters to plot
 * @param dir           output directory prefix, without a trailing slash
 * @throws Exception if plotting fails
 */
private static void plotMnistClusters(KMeans kMeans, DataSet dataSet, int numComponents, String dir)
        throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String clusterDir = dir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], 28, 28, clusterDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), 28, 28, clusterDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
Example: method test6 in class BMTrainerTest.
/**
 * Trains a BM on a small hand-built 20x5 dense dataset and prints the results.
 *
 * <p>The data has three row groups:
 * <ul>
 *   <li>rows 0-4 set feature 0 to 1;</li>
 *   <li>rows 5-9 set feature 1 to 1;</li>
 *   <li>rows 10-19 set features 1, 2 and 3 to 1.</li>
 * </ul>
 *
 * <p>It prints the dataset, the trainer's model before training, the trained model, and five
 * samples drawn from it. The arguments {@code 3} and {@code 0} passed to {@code BMTrainer} are
 * interpreted by that class (not visible here).
 */
private static void test6() {
    DataSet data = DataSetBuilder.getBuilder().numFeatures(5).numDataPoints(20).dense(true).build();
    for (int row = 0; row < 20; row++) {
        if (row < 5) {
            data.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            data.setFeatureValue(row, 1, 1);
        } else {
            for (int col = 1; col <= 3; col++) {
                data.setFeatureValue(row, col, 1);
            }
        }
    }
    System.out.println("dataset = " + data);
    BMTrainer bmTrainer = new BMTrainer(data, 3, 0);
    System.out.println(bmTrainer.bm);
    BM trained = bmTrainer.train();
    System.out.println(trained);
    for (int s = 0; s < 5; s++) {
        System.out.println("sample " + s);
        System.out.println(trained.sample());
    }
}
Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
Example: method test3 in class BMTrainerTest.
/**
 * Trains a BM on a 2x1 dense dataset whose first row has feature 0 set to 1.
 *
 * <p>Prints the dataset, the trainer's model before training, and the trained model. The
 * arguments {@code 2} and {@code 0} passed to {@code BMTrainer} are interpreted by that class
 * (not visible here).
 */
private static void test3() {
    DataSet data = DataSetBuilder.getBuilder().numFeatures(1).numDataPoints(2).dense(true).build();
    data.setFeatureValue(0, 0, 1);
    System.out.println("dataset = " + data);
    BMTrainer bmTrainer = new BMTrainer(data, 2, 0);
    System.out.println(bmTrainer.bm);
    BM trained = bmTrainer.train();
    System.out.println(trained);
}
Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.
Example: method test1 in class BMTrainerTest.
/**
 * Trains a BM on a 3x1 dense dataset whose first two rows have feature 0 set to 1.
 *
 * <p>Row 2 is left untouched. Prints the dataset, the trainer's model before training, and the
 * trained model. The arguments {@code 1} and {@code 0} passed to {@code BMTrainer} are
 * interpreted by that class (not visible here).
 */
private static void test1() {
    DataSet data = DataSetBuilder.getBuilder().numFeatures(1).numDataPoints(3).dense(true).build();
    for (int row = 0; row < 2; row++) {
        data.setFeatureValue(row, 0, 1);
    }
    System.out.println("dataset = " + data);
    BMTrainer bmTrainer = new BMTrainer(data, 1, 0);
    System.out.println(bmTrainer.bm);
    BM trained = bmTrainer.train();
    System.out.println(trained);
}
Aggregations