Search in sources:

Example 1 with KMeans

Use of edu.neu.ccs.pyramid.clustering.kmeans.KMeans in project pyramid by cheng-li.

Class GMMTrainerTest, method fashion.

/**
 * Demo: fits a Gaussian mixture model to 100 randomly sampled 28x28 images and plots each
 * component mean after every training iteration.
 *
 * <p>Each line of {@code features.txt} is split on commas and every value is divided by 255.
 * The values are stored in a {@code DataSet} and then copied into a {@link RealMatrix}. The
 * method runs {@code GMMTrainer.iterate()} five times, printing the summed
 * {@code gmm.logDensity} over all rows after each iteration. At the end it prints every
 * component's mean and log-determinant. Center images are written under
 * {@code kmeans_demo/clusters/iter_<n>/cluster_<k>/center.png}.
 *
 * @throws IllegalStateException if the feature file has fewer lines than the sample size
 * @throws IllegalArgumentException if a sampled line has more values than an image has pixels
 * @throws Exception if reading the input, preparing the output directory, training, or
 *     plotting fails
 */
private static void fashion() throws Exception {
    final String outputDir = "/Users/chengli/tmp/kmeans_demo";
    final int width = 28;
    final int height = 28;
    final int numPixels = width * height;
    final int rows = 100;
    final int numComponents = 5;
    final int numIterations = 5;

    // cleanDirectory alone throws when the directory does not exist yet (e.g. on a first run);
    // forceMkdir creates it if needed and is a no-op otherwise.
    File outputRoot = new File(outputDir);
    FileUtils.forceMkdir(outputRoot);
    FileUtils.cleanDirectory(outputRoot);

    // Explicit charset: the one-argument readLines is deprecated and uses the platform default.
    List<String> lines = FileUtils.readLines(
            new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt"),
            java.nio.charset.StandardCharsets.UTF_8);
    if (lines.size() < rows) {
        throw new IllegalStateException(
                "need at least " + rows + " lines of features but found " + lines.size());
    }
    // Fixed seed so the same subset of lines is sampled on every run.
    Collections.shuffle(lines, new Random(0));
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(numPixels).build();
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(",");
        if (split.length > numPixels) {
            throw new IllegalArgumentException("sampled row " + i + " has " + split.length
                    + " values but images have only " + numPixels + " pixels");
        }
        for (int j = 0; j < split.length; j++) {
            // Presumably 8-bit pixel intensities; dividing by 255 maps them into [0, 1].
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]) / 255);
        }
    }
    // Disabled experiment: k-means++ initialisation with per-iteration cluster plots, kept for reference.
    // KMeans kMeans = new KMeans(numComponents, dataSet);
    // //        kMeans.randomInitialize();
    // kMeans.kmeansPlusPlusInitialize(100);
    // List<Double> objectives = new ArrayList<>();
    // boolean showInitialize = true;
    // if (showInitialize){
    // int[] assignment = kMeans.getAssignments();
    // for (int k=0;k<numComponents;k++){
    // plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/initial/cluster_"+(k+1)+"/center.png");
    // //                plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"_component_"+(k+1)+"_pic_000center.png");
    // 
    // int counter = 0;
    // for (int i=0;i<assignment.length;i++){
    // if (assignment[i]==k){
    // plot(dataSet.getRow(i), 28,28,
    // "/Users/chengli/tmp/kmeans_demo/clusters/initial/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
    // counter+=1;
    // }
    // 
    // }
    // }
    // }
    // objectives.add(kMeans.objective());
    // 
    // 
    // for (int iter=1;iter<=5;iter++){
    // System.out.println("=====================================");
    // System.out.println("iteration "+iter);
    // kMeans.iterate();
    // objectives.add(kMeans.objective());
    // int[] assignment = kMeans.getAssignments();
    // for (int k=0;k<numComponents;k++){
    // plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/center.png");
    // for (int i=0;i<assignment.length;i++){
    // if (assignment[i]==k){
    // plot(dataSet.getRow(i), 28,28,
    // "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
    // }
    // }
    // }
    // 
    // System.out.println("training objective changes: "+objectives);
    // }
    // int[] assignments = kMeans.getAssignments();
    int numFeatures = dataSet.getNumFeatures();
    RealMatrix data = new Array2DRowRealMatrix(rows, numFeatures);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < numFeatures; j++) {
            data.setEntry(i, j, dataSet.getRow(i).get(j));
        }
    }
    GMM gmm = new GMM(numFeatures, numComponents, data);
    GMMTrainer trainer = new GMMTrainer(data, gmm);
    // Disabled: seed the trainer's responsibilities from hard k-means assignments.
    // double[][] gammas = new double[assignments.length][numComponents];
    // for (int i=0;i<assignments.length;i++){
    // gammas[i][assignments[i]]=1;
    // }
    // trainer.setGammas(gammas);
    System.out.println("start training GMM");
    for (int iter = 1; iter <= numIterations; iter++) {
        // trainer.mStep();
        // trainer.eStep();
        trainer.iterate();
        System.out.println("iteration " + iter);
        // double[] entropies = IntStream.range(0,rows).mapToDouble(i->Entropy.entropy(gammas[i])).toArray();
        // System.out.println(Arrays.toString(entropies));
        // int max = ArgMax.argMax(entropies);
        // plot(dataSet.getRow(max), 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/max_entropy.png");
        // System.out.println(Arrays.toString(gammas[max]));
        // Sum of per-row log densities under the current model.
        double logLikelihood = IntStream.range(0, rows).parallel().mapToDouble(j -> gmm.logDensity(data.getRowVector(j))).sum();
        System.out.println("log likelihood = " + logLikelihood);
        for (int k = 0; k < numComponents; k++) {
            plot(gmm.getGaussianDistributions()[k].getMean(), width, height,
                    outputDir + "/clusters/iter_" + iter + "/cluster_" + (k + 1) + "/center.png");
        // for (int i=0;i<assignment.length;i++){
        // if (assignment[i]==k){
        // plot(dataSet.getRow(i), 28,28,
        // "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
        // }
        // }
        }
    }
    for (int k = 0; k < numComponents; k++) {
        System.out.println("component " + k);
        System.out.println("mean=" + gmm.getGaussianDistributions()[k].getMean());
        System.out.println("log determinant =" + gmm.getGaussianDistributions()[k].getLogDeterminant());
    }
// double[][] gammas = trainer.getGammas();
// double[] entropies = IntStream.range(0,rows).mapToDouble(i->Entropy.entropy(gammas[i])).toArray();
// System.out.println(Arrays.toString(entropies));
// int max = ArgMax.argMax(entropies);
// plot(dataSet.getRow(max), 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/max_entropy.png");
// System.out.println(Arrays.toString(gammas[max]));
// System.out.println(gmm);
}
Also used: IntStream (java.util.stream.IntStream), Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix), java.util (java.util), BufferedImage (java.awt.image.BufferedImage), ArgMax (edu.neu.ccs.pyramid.util.ArgMax), FileUtils (org.apache.commons.io.FileUtils), RealVector (org.apache.commons.math3.linear.RealVector), File (java.io.File), KMeans (edu.neu.ccs.pyramid.clustering.kmeans.KMeans), Serialization (edu.neu.ccs.pyramid.util.Serialization), DescriptiveStatistics (org.apache.commons.math3.stat.descriptive.DescriptiveStatistics), Entropy (edu.neu.ccs.pyramid.eval.Entropy), ImageIO (javax.imageio.ImageIO), edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset), RealMatrix (org.apache.commons.math3.linear.RealMatrix), BM (edu.neu.ccs.pyramid.clustering.bm.BM), BMSelector (edu.neu.ccs.pyramid.clustering.bm.BMSelector)

Example 2 with KMeans

Use of edu.neu.ccs.pyramid.clustering.kmeans.KMeans in project pyramid by cheng-li.

Class GMMTrainerTest, method spam.

/**
 * Demo: fits a 5-component Gaussian mixture model to the spam training set and reports
 * training progress.
 *
 * <p>The training set is loaded with {@code TRECFormat.loadClfDataSet} as
 * {@code DataSetType.CLF_DENSE} and copied into a {@link RealMatrix}. The method runs
 * {@code GMMTrainer.iterate()} 50 times, printing the summed {@code gmm.logDensity} over all
 * rows after each iteration. At the end it prints every component's mean and log-determinant.
 *
 * @throws Exception if preparing the output directory, loading the data, or training fails
 */
private static void spam() throws Exception {
    final int numComponents = 5;
    final int numIterations = 50;

    // NOTE(review): nothing below writes into kmeans_demo; this clean looks like a leftover from
    // the image demos. forceMkdir keeps cleanDirectory from throwing when the directory is missing.
    File outputRoot = new File("/Users/chengli/tmp/kmeans_demo");
    FileUtils.forceMkdir(outputRoot);
    FileUtils.cleanDirectory(outputRoot);
    DataSet dataSet = TRECFormat.loadClfDataSet("/Users/chengli/tmp/spam/train", DataSetType.CLF_DENSE, true);
    // DataSet dataSet = TRECFormat.loadRegDataSet("/Users/chengli/tmp/housing/train",DataSetType.REG_DENSE,true);
    // Disabled experiment: k-means++ initialisation with per-iteration cluster plots, kept for reference.
    // KMeans kMeans = new KMeans(numComponents, dataSet);
    // //        kMeans.randomInitialize();
    // kMeans.kmeansPlusPlusInitialize(100);
    // List<Double> objectives = new ArrayList<>();
    // boolean showInitialize = true;
    // if (showInitialize){
    // int[] assignment = kMeans.getAssignments();
    // for (int k=0;k<numComponents;k++){
    // plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/initial/cluster_"+(k+1)+"/center.png");
    // //                plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"_component_"+(k+1)+"_pic_000center.png");
    // 
    // int counter = 0;
    // for (int i=0;i<assignment.length;i++){
    // if (assignment[i]==k){
    // plot(dataSet.getRow(i), 28,28,
    // "/Users/chengli/tmp/kmeans_demo/clusters/initial/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
    // counter+=1;
    // }
    // 
    // }
    // }
    // }
    // objectives.add(kMeans.objective());
    // 
    // 
    // for (int iter=1;iter<=5;iter++){
    // System.out.println("=====================================");
    // System.out.println("iteration "+iter);
    // kMeans.iterate();
    // objectives.add(kMeans.objective());
    // int[] assignment = kMeans.getAssignments();
    // for (int k=0;k<numComponents;k++){
    // plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/center.png");
    // for (int i=0;i<assignment.length;i++){
    // if (assignment[i]==k){
    // plot(dataSet.getRow(i), 28,28,
    // "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
    // }
    // }
    // }
    // 
    // System.out.println("training objective changes: "+objectives);
    // }
    // int[] assignments = kMeans.getAssignments();

    // The data set is not modified below, so its dimensions can be read once.
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    RealMatrix data = new Array2DRowRealMatrix(numDataPoints, numFeatures);
    for (int i = 0; i < numDataPoints; i++) {
        for (int j = 0; j < numFeatures; j++) {
            data.setEntry(i, j, dataSet.getRow(i).get(j));
        }
    }
    GMM gmm = new GMM(numFeatures, numComponents, data);
    GMMTrainer trainer = new GMMTrainer(data, gmm);
    // Disabled: seed the trainer's responsibilities from hard k-means assignments.
    // double[][] gammas = new double[assignments.length][numComponents];
    // for (int i=0;i<assignments.length;i++){
    // gammas[i][assignments[i]]=1;
    // }
    // trainer.setGammas(gammas);
    System.out.println("start training GMM");
    for (int iter = 1; iter <= numIterations; iter++) {
        System.out.println("iteration " + iter);
        // trainer.mStep();
        // trainer.eStep();
        trainer.iterate();
        // Sum of per-row log densities under the current model.
        double logLikelihood = IntStream.range(0, numDataPoints).parallel().mapToDouble(j -> gmm.logDensity(data.getRowVector(j))).sum();
        System.out.println("log likelihood = " + logLikelihood);
    }
    for (int k = 0; k < numComponents; k++) {
        System.out.println("component " + k);
        System.out.println("mean=" + gmm.getGaussianDistributions()[k].getMean());
        System.out.println("log determinant =" + gmm.getGaussianDistributions()[k].getLogDeterminant());
    }
// double[][] gammas = trainer.getGammas();
// double[] entropies = IntStream.range(0,dataSet.getNumDataPoints()).mapToDouble(i->Entropy.entropy(gammas[i])).toArray();
// System.out.println(Arrays.toString(entropies));
// int max = ArgMax.argMax(entropies);
// 
// System.out.println(Arrays.toString(gammas[max]));
// System.out.println(gmm);
}
Also used: IntStream (java.util.stream.IntStream), Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix), java.util (java.util), BufferedImage (java.awt.image.BufferedImage), ArgMax (edu.neu.ccs.pyramid.util.ArgMax), FileUtils (org.apache.commons.io.FileUtils), RealVector (org.apache.commons.math3.linear.RealVector), File (java.io.File), KMeans (edu.neu.ccs.pyramid.clustering.kmeans.KMeans), Serialization (edu.neu.ccs.pyramid.util.Serialization), DescriptiveStatistics (org.apache.commons.math3.stat.descriptive.DescriptiveStatistics), Entropy (edu.neu.ccs.pyramid.eval.Entropy), ImageIO (javax.imageio.ImageIO), edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset), RealMatrix (org.apache.commons.math3.linear.RealMatrix), BM (edu.neu.ccs.pyramid.clustering.bm.BM), BMSelector (edu.neu.ccs.pyramid.clustering.bm.BMSelector)

Aggregations

BM (edu.neu.ccs.pyramid.clustering.bm.BM): 2; BMSelector (edu.neu.ccs.pyramid.clustering.bm.BMSelector): 2; KMeans (edu.neu.ccs.pyramid.clustering.kmeans.KMeans): 2; edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 2; Entropy (edu.neu.ccs.pyramid.eval.Entropy): 2; ArgMax (edu.neu.ccs.pyramid.util.ArgMax): 2; Serialization (edu.neu.ccs.pyramid.util.Serialization): 2; BufferedImage (java.awt.image.BufferedImage): 2; File (java.io.File): 2; java.util (java.util): 2; IntStream (java.util.stream.IntStream): 2; ImageIO (javax.imageio.ImageIO): 2; FileUtils (org.apache.commons.io.FileUtils): 2; Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 2; RealMatrix (org.apache.commons.math3.linear.RealMatrix): 2; RealVector (org.apache.commons.math3.linear.RealVector): 2; DescriptiveStatistics (org.apache.commons.math3.stat.descriptive.DescriptiveStatistics): 2