Search in sources:

Example 16 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

The class KMeansTest, method fashion.

/**
 * Demo: runs k-means (10 clusters, k-means++ initialization) on a shuffled sample of up to 100
 * lines of the fashion feature file, then plots every cluster center and every member row.
 * Plots are made first for the initial assignment and then after each of 5 iterations. The
 * training objective is recorded after initialization and after every iteration, and the
 * running list is printed once per iteration.
 *
 * <p>Each input line is parsed as comma-separated doubles. The 28 * 28 feature count suggests
 * one flattened 28x28 image per line (assumption; confirm against the data file).
 *
 * @throws IllegalStateException if a sampled line has more than 28 * 28 values
 * @throws Exception if reading the input, preparing the output directory, or plotting fails
 */
private static void fashion() throws Exception {
    String outputDir = "/Users/chengli/tmp/kmeans_demo";
    File outputRoot = new File(outputDir);
    // cleanDirectory() rejects a missing directory, so create it first when needed.
    FileUtils.forceMkdir(outputRoot);
    FileUtils.cleanDirectory(outputRoot);
    // Decode explicitly instead of depending on the platform default charset.
    List<String> lines = FileUtils.readLines(
            new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt"),
            java.nio.charset.StandardCharsets.UTF_8);
    // A fixed seed keeps the sampled rows reproducible across runs.
    Collections.shuffle(lines, new Random(0));
    int imageSide = 28;
    int numFeatures = imageSide * imageSide;
    // Never request more rows than the file actually provides.
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(numFeatures).build();
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(",");
        if (split.length > numFeatures) {
            throw new IllegalStateException("sampled line " + i + " has " + split.length
                    + " values, expected at most " + numFeatures);
        }
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    int numComponents = 10;
    KMeans kMeans = new KMeans(numComponents, dataSet);
    // kMeans.randomInitialize();
    // NOTE(review): the meaning of the argument 100 is defined by KMeans and is not visible here.
    kMeans.kmeansPlusPlusInitialize(100);
    List<Double> objectives = new ArrayList<>();
    // Toggle for plotting the clusters produced by initialization.
    boolean showInitialize = true;
    if (showInitialize) {
        plotFashionClusters(kMeans, dataSet, numComponents, imageSide, outputDir + "/clusters/initial");
    }
    objectives.add(kMeans.objective());
    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotFashionClusters(kMeans, dataSet, numComponents, imageSide, outputDir + "/clusters/iter_" + iter);
        System.out.println("training objective changes: " + objectives);
    }
}

/**
 * Plots the current clustering of {@code kMeans}.
 *
 * <p>Center {@code k} goes to {@code <clusterDir>/cluster_<k+1>/center.png}. Each row {@code i}
 * currently assigned to cluster {@code k} goes to {@code <clusterDir>/cluster_<k+1>/pic_<i+1>.png}.
 */
private static void plotFashionClusters(KMeans kMeans, DataSet dataSet, int numComponents,
        int imageSide, String clusterDir) throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String componentDir = clusterDir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], imageSide, imageSide, componentDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), imageSide, imageSide, componentDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
Also used: DataSet(edu.neu.ccs.pyramid.dataset.DataSet) File(java.io.File)

Example 17 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

The class KMeansTest, method mnistSubgroup.

/**
 * Demo: runs k-means (10 clusters, k-means++ initialization) on up to the first 100 feature
 * lines whose label equals {@code label}, then plots every cluster center and every member row.
 * Plots are made first for the initial assignment and then after each of 5 iterations. The
 * training objective is recorded after initialization and after every iteration, and the
 * running list is printed once per iteration.
 *
 * <p>The labels file holds one value per line. Each value is parsed as a double and truncated
 * to int. Feature lines hold whitespace-separated doubles and are paired with labels by line
 * index.
 *
 * @param label the label value (as parsed from the labels file) whose rows are clustered
 * @throws IllegalStateException if there are fewer labels than feature lines, or a selected
 *     line has more than 28 * 28 values
 * @throws IllegalArgumentException if no feature line carries {@code label}
 * @throws Exception if reading the input, preparing the output directory, or plotting fails
 */
private static void mnistSubgroup(int label) throws Exception {
    String outputDir = "/Users/chengli/tmp/kmeans_demo";
    File outputRoot = new File(outputDir);
    // cleanDirectory() rejects a missing directory, so create it first when needed.
    FileUtils.forceMkdir(outputRoot);
    FileUtils.cleanDirectory(outputRoot);
    String dataDir = "/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2";
    // Decode explicitly instead of depending on the platform default charset.
    List<Integer> labels = FileUtils.readLines(new File(dataDir + "/mnist_labels.txt"),
            java.nio.charset.StandardCharsets.UTF_8).stream()
            .mapToInt(l -> (int) Double.parseDouble(l))
            .boxed()
            .collect(Collectors.toList());
    List<String> features = FileUtils.readLines(new File(dataDir + "/mnist_features.txt"),
            java.nio.charset.StandardCharsets.UTF_8);
    // Labels are matched to feature lines by index, so every feature line needs a label.
    if (labels.size() < features.size()) {
        throw new IllegalStateException("found " + labels.size() + " labels for "
                + features.size() + " feature lines");
    }
    List<String> lines = IntStream.range(0, features.size())
            .filter(i -> labels.get(i) == label)
            .mapToObj(features::get)
            .collect(Collectors.toList());
    if (lines.isEmpty()) {
        throw new IllegalArgumentException("no feature lines have label " + label);
    }
    int imageSide = 28;
    int numFeatures = imageSide * imageSide;
    // Never request more rows than there are matching lines.
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(numFeatures).build();
    for (int i = 0; i < rows; i++) {
        // trim + "\\s+" avoids empty tokens from leading or repeated spaces.
        String[] split = lines.get(i).trim().split("\\s+");
        if (split.length > numFeatures) {
            throw new IllegalStateException("selected line " + i + " has " + split.length
                    + " values, expected at most " + numFeatures);
        }
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    int numComponents = 10;
    KMeans kMeans = new KMeans(numComponents, dataSet);
    // kMeans.randomInitialize();
    // NOTE(review): the meaning of the argument 100 is defined by KMeans and is not visible here.
    kMeans.kmeansPlusPlusInitialize(100);
    List<Double> objectives = new ArrayList<>();
    // Toggle for plotting the clusters produced by initialization.
    boolean showInitialize = true;
    if (showInitialize) {
        plotMnistClusters(kMeans, dataSet, numComponents, imageSide, outputDir + "/clusters/initial");
    }
    objectives.add(kMeans.objective());
    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotMnistClusters(kMeans, dataSet, numComponents, imageSide, outputDir + "/clusters/iter_" + iter);
        System.out.println("training objective changes: " + objectives);
    }
}

/**
 * Plots the current clustering of {@code kMeans}.
 *
 * <p>Center {@code k} goes to {@code <clusterDir>/cluster_<k+1>/center.png}. Each row {@code i}
 * currently assigned to cluster {@code k} goes to {@code <clusterDir>/cluster_<k+1>/pic_<i+1>.png}.
 */
private static void plotMnistClusters(KMeans kMeans, DataSet dataSet, int numComponents,
        int imageSide, String clusterDir) throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String componentDir = clusterDir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], imageSide, imageSide, componentDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), imageSide, imageSide, componentDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
Also used: IntStream(java.util.stream.IntStream) java.util(java.util) BufferedImage(java.awt.image.BufferedImage) ArgMax(edu.neu.ccs.pyramid.util.ArgMax) DataSet(edu.neu.ccs.pyramid.dataset.DataSet) FileUtils(org.apache.commons.io.FileUtils) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) DataSetBuilder(edu.neu.ccs.pyramid.dataset.DataSetBuilder) ImageIO(javax.imageio.ImageIO) Accuracy(edu.neu.ccs.pyramid.eval.Accuracy) Vector(org.apache.mahout.math.Vector)

Example 18 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

The class BMTrainerTest, method test6.

/**
 * Builds a dense data set of 20 points and 5 features made of three row groups:
 * <ul>
 *   <li>points 0-4 set feature 0 to 1;</li>
 *   <li>points 5-9 set feature 1 to 1;</li>
 *   <li>points 10-19 set features 1, 2 and 3 to 1.</li>
 * </ul>
 * Feature 4 is never set explicitly. The method prints the data set, prints
 * {@code trainer.bm} before training, trains, prints the trained BM, and prints five samples
 * drawn from it.
 */
private static void test6() {
    DataSet dataSet = DataSetBuilder.getBuilder()
            .numFeatures(5)
            .numDataPoints(20)
            .dense(true)
            .build();
    for (int row = 0; row < 20; row++) {
        if (row < 5) {
            dataSet.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            dataSet.setFeatureValue(row, 1, 1);
        } else {
            for (int feature = 1; feature <= 3; feature++) {
                dataSet.setFeatureValue(row, feature, 1);
            }
        }
    }
    System.out.println("dataset = " + dataSet);
    BMTrainer trainer = new BMTrainer(dataSet, 3, 0);
    System.out.println(trainer.bm);
    BM trained = trainer.train();
    System.out.println(trained);
    int sampleIndex = 0;
    while (sampleIndex < 5) {
        System.out.println("sample " + sampleIndex);
        System.out.println(trained.sample());
        sampleIndex++;
    }
}
Also used: DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Example 19 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

The class BMTrainerTest, method test3.

/**
 * Trains a BMTrainer, built with constructor arguments (2, 0), on a dense data set of two
 * points and one feature.
 *
 * <p>Only point 0 has its feature set, to 1. The method prints the data set, then
 * {@code trainer.bm} before training, then the trained BM.
 */
private static void test3() {
    final DataSet dataSet = DataSetBuilder.getBuilder()
            .numFeatures(1)
            .numDataPoints(2)
            .dense(true)
            .build();
    // Only the first point is set explicitly; point 1 keeps whatever the builder initialized.
    dataSet.setFeatureValue(0, 0, 1);
    System.out.println("dataset = " + dataSet);
    final BMTrainer trainer = new BMTrainer(dataSet, 2, 0);
    System.out.println(trainer.bm);
    System.out.println(trainer.train());
}
Also used: DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Example 20 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

The class BMTrainerTest, method test1.

/**
 * Trains a BMTrainer, built with constructor arguments (1, 0), on a dense data set of three
 * points and one feature.
 *
 * <p>Points 0 and 1 have their feature set to 1, and point 2 is not set explicitly. The method
 * prints the data set, then {@code trainer.bm} before training, then the trained BM.
 */
private static void test1() {
    final DataSet dataSet = DataSetBuilder.getBuilder()
            .numFeatures(1)
            .numDataPoints(3)
            .dense(true)
            .build();
    for (int point = 0; point < 2; point++) {
        dataSet.setFeatureValue(point, 0, 1);
    }
    System.out.println("dataset = " + dataSet);
    final BMTrainer trainer = new BMTrainer(dataSet, 1, 0);
    System.out.println(trainer.bm);
    System.out.println(trainer.train());
}
Also used: DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Aggregations

DataSet (edu.neu.ccs.pyramid.dataset.DataSet): 20; Vector (org.apache.mahout.math.Vector): 7; File (java.io.File): 6; java.util (java.util): 4; Collectors (java.util.stream.Collectors): 4; IntStream (java.util.stream.IntStream): 4; DataSetBuilder (edu.neu.ccs.pyramid.dataset.DataSetBuilder): 3; Accuracy (edu.neu.ccs.pyramid.eval.Accuracy): 3; ArgMax (edu.neu.ccs.pyramid.util.ArgMax): 3; BufferedImage (java.awt.image.BufferedImage): 3; List (java.util.List): 3; ImageIO (javax.imageio.ImageIO): 3; FileUtils (org.apache.commons.io.FileUtils): 3; DenseVector (org.apache.mahout.math.DenseVector): 3; MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 2; Ngram (edu.neu.ccs.pyramid.feature.Ngram): 2; IOException (java.io.IOException): 2; ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability): 1; PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis): 1; ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1