Use of edu.neu.ccs.pyramid.clustering.kmeans.KMeans in the project pyramid by cheng-li.
Class GMMTrainerTest, method fashion.
/**
 * Demo: fits a Gaussian mixture model to a random sample of Fashion image rows and renders each
 * component mean as a 28x28 picture after every training iteration.
 *
 * <p>Reads comma-separated feature rows from a hard-coded features file and drops blank lines.
 * It shuffles the rows with a fixed seed (0) and keeps at most 100 of them. Every value is divided
 * by 255. The model has 5 components and is trained for 5 iterations. After each iteration the
 * method prints the summed log density of the sample and plots each component mean. At the end it
 * prints every component's mean and log determinant.
 *
 * @throws Exception if the demo directory cannot be cleaned, the features file cannot be read or
 *     contains no data rows, or a feature value cannot be parsed as a double
 */
private static void fashion() throws Exception {
    final String demoDir = "/Users/chengli/tmp/kmeans_demo";
    FileUtils.cleanDirectory(new File(demoDir));
    List<String> lines = FileUtils.readLines(
            new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt"),
            java.nio.charset.StandardCharsets.UTF_8);
    // A blank line would split into [""] and make Double.parseDouble throw.
    lines.removeIf(line -> line.trim().isEmpty());
    Collections.shuffle(lines, new Random(0));

    // Use at most 100 rows, but never index past the end of a smaller file.
    int rows = Math.min(100, lines.size());
    if (rows == 0) {
        throw new IllegalStateException("features file contains no data rows");
    }
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(",");
        for (int j = 0; j < split.length; j++) {
            // Divides by 255, presumably to map 8-bit pixel intensities into [0, 1].
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]) / 255);
        }
    }

    int numComponents = 5;
    // NOTE: an earlier experiment seeded the GMM from a KMeans clustering (kmeans++ initialization,
    // then trainer.setGammas(...) with one-hot assignments). It was disabled and has been removed.

    int numFeatures = dataSet.getNumFeatures();
    RealMatrix data = new Array2DRowRealMatrix(rows, numFeatures);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < numFeatures; j++) {
            data.setEntry(i, j, dataSet.getRow(i).get(j));
        }
    }

    GMM gmm = new GMM(numFeatures, numComponents, data);
    GMMTrainer trainer = new GMMTrainer(data, gmm);
    System.out.println("start training GMM");
    for (int iter = 1; iter <= 5; iter++) {
        trainer.iterate();
        System.out.println("iteration " + iter);
        double logLikelihood = IntStream.range(0, rows)
                .parallel()
                .mapToDouble(r -> gmm.logDensity(data.getRowVector(r)))
                .sum();
        System.out.println("log likelihood = " + logLikelihood);
        for (int k = 0; k < numComponents; k++) {
            plot(gmm.getGaussianDistributions()[k].getMean(), 28, 28,
                    demoDir + "/clusters/iter_" + iter + "/cluster_" + (k + 1) + "/center.png");
        }
    }

    for (int k = 0; k < numComponents; k++) {
        System.out.println("component " + k);
        System.out.println("mean=" + gmm.getGaussianDistributions()[k].getMean());
        System.out.println("log determinant =" + gmm.getGaussianDistributions()[k].getLogDeterminant());
    }
}
Use of edu.neu.ccs.pyramid.clustering.kmeans.KMeans in the project pyramid by cheng-li.
Class GMMTrainerTest, method spam.
/**
 * Demo: fits a Gaussian mixture model to the spam training set and prints the model's fit after
 * every iteration.
 *
 * <p>Loads the data set in dense classification form from a hard-coded TREC-format directory and
 * copies it into a {@code RealMatrix}. It trains a 5-component model for 50 iterations and prints
 * the summed log density of all points after each one. At the end it prints every component's mean
 * and log determinant.
 *
 * @throws Exception if the demo directory cannot be cleaned or the data set cannot be loaded
 */
private static void spam() throws Exception {
    FileUtils.cleanDirectory(new File("/Users/chengli/tmp/kmeans_demo"));
    DataSet dataSet = TRECFormat.loadClfDataSet("/Users/chengli/tmp/spam/train", DataSetType.CLF_DENSE, true);
    int numComponents = 5;
    // NOTE: disabled experiments were removed from this method. One loaded the housing regression
    // set instead. The other seeded the GMM from a KMeans clustering (kmeans++ initialization, then
    // trainer.setGammas(...)).

    int numPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();

    // Copy the data set into a dense matrix, entry by entry.
    RealMatrix data = new Array2DRowRealMatrix(numPoints, numFeatures);
    for (int row = 0; row < numPoints; row++) {
        for (int col = 0; col < numFeatures; col++) {
            data.setEntry(row, col, dataSet.getRow(row).get(col));
        }
    }

    GMM gmm = new GMM(numFeatures, numComponents, data);
    GMMTrainer trainer = new GMMTrainer(data, gmm);
    System.out.println("start training GMM");
    for (int iteration = 1; iteration <= 50; iteration++) {
        System.out.println("iteration " + iteration);
        trainer.iterate();
        // Total log density of every training point under the current model.
        double logLikelihood = IntStream.range(0, numPoints)
                .parallel()
                .mapToDouble(r -> gmm.logDensity(data.getRowVector(r)))
                .sum();
        System.out.println("log likelihood = " + logLikelihood);
    }

    for (int component = 0; component < numComponents; component++) {
        System.out.println("component " + component);
        System.out.println("mean=" + gmm.getGaussianDistributions()[component].getMean());
        System.out.println("log determinant =" + gmm.getGaussianDistributions()[component].getLogDeterminant());
    }
}
Aggregations