Search in sources :

Example 6 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class BMTrainerTest method test5.

/**
 * Trains a 3-component BM on a dense 20 x 5 data set made of three row
 * groups and prints the model before and after training, followed by the
 * probability the trained model gives to one representative vector per group.
 *
 * Row groups: rows 0-4 have feature 0 set, rows 5-9 have feature 1 set,
 * rows 10-19 have features 2 and 3 set.
 */
private static void test5() {
    final int numRows = 20;
    final int numFeatures = 5;
    DataSet dataSet = DataSetBuilder.getBuilder()
            .numFeatures(numFeatures)
            .numDataPoints(numRows)
            .dense(true)
            .build();
    // One pass over the rows; the row index decides which features are switched on.
    for (int row = 0; row < numRows; row++) {
        if (row < 5) {
            dataSet.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            dataSet.setFeatureValue(row, 1, 1);
        } else {
            dataSet.setFeatureValue(row, 2, 1);
            dataSet.setFeatureValue(row, 3, 1);
        }
    }
    System.out.println("dataset = " + dataSet);

    BMTrainer trainer = new BMTrainer(dataSet, 3, 0);
    // Model as constructed by the trainer, before any training.
    System.out.println(trainer.bm);
    trainer.train();
    // Model after training.
    System.out.println(trainer.bm);

    // One query vector per row group, mirroring the patterns written above.
    Vector firstGroup = new DenseVector(numFeatures);
    firstGroup.set(0, 1);
    Vector secondGroup = new DenseVector(numFeatures);
    secondGroup.set(1, 1);
    Vector thirdGroup = new DenseVector(numFeatures);
    thirdGroup.set(2, 1);
    thirdGroup.set(3, 1);

    Vector[] queries = { firstGroup, secondGroup, thirdGroup };
    for (Vector query : queries) {
        // logProbability is exponentiated so a plain probability is printed.
        System.out.println(Math.exp(trainer.bm.logProbability(query)));
    }
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) DenseVector(org.apache.mahout.math.DenseVector)

Example 7 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class BMTrainerTest method test2.

/**
 * Smallest BM training example: a dense 2 x 1 data set in which only the
 * first row has its single feature set, fitted with a 1-component trainer.
 * Prints the data set, the model before training, and the model returned
 * by {@code train()}.
 */
private static void test2() {
    DataSet dataSet = DataSetBuilder.getBuilder()
            .numFeatures(1)
            .numDataPoints(2)
            .dense(true)
            .build();
    dataSet.setFeatureValue(0, 0, 1);
    System.out.println("dataset = " + dataSet);

    BMTrainer trainer = new BMTrainer(dataSet, 1, 0);
    // Model before training.
    System.out.println(trainer.bm);

    // train() hands back the fitted model; print that returned instance.
    BM trained = trainer.train();
    System.out.println(trained);
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Example 8 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class BMTrainerTest method test4.

/**
 * Trains a 2-component BM on a dense 10 x 2 data set whose first five rows
 * have feature 0 set and whose last five rows have feature 1 set, printing
 * the model before and after training.
 */
private static void test4() {
    final int numRows = 10;
    DataSet dataSet = DataSetBuilder.getBuilder()
            .numFeatures(2)
            .numDataPoints(numRows)
            .dense(true)
            .build();
    // Rows 0-4 switch on feature 0, rows 5-9 switch on feature 1.
    for (int row = 0; row < numRows; row++) {
        int activeFeature = row < 5 ? 0 : 1;
        dataSet.setFeatureValue(row, activeFeature, 1);
    }
    System.out.println("dataset = " + dataSet);

    BMTrainer trainer = new BMTrainer(dataSet, 2, 0);
    // Model before training.
    System.out.println(trainer.bm);
    trainer.train();
    // Model after training.
    System.out.println(trainer.bm);
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Example 9 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class KMeansTest method mnist.

/**
 * K-means demo on a 100-image sample of MNIST.
 *
 * Steps: empty the output directory, read the feature and label files,
 * shuffle both with the same seed, load the first 100 rows into a data set,
 * initialize 20 centers with k-means++, then run 5 iterations. The centers
 * and the member images of every cluster are plotted for the initial state
 * and after each iteration. Finally the purity against the true labels and
 * the list of objective values are printed.
 *
 * @throws Exception on any I/O, parsing or plotting failure
 * @throws IllegalStateException if the feature and label files do not have
 *         the same number of lines
 */
private static void mnist() throws Exception {
    FileUtils.cleanDirectory(new File("/Users/chengli/tmp/kmeans_demo"));
    List<String> lines = FileUtils.readLines(new File("/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2/mnist/mnist_features.txt"));
    List<Integer> inputLabels = FileUtils.readLines(new File("/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2/mnist/mnist_labels.txt")).stream().mapToInt(line -> (int) Double.parseDouble(line)).boxed().collect(Collectors.toList());
    // Two shuffles with the same seed apply the same permutation only when the
    // lists have equal length; otherwise images and labels silently come apart
    // and the purity printed at the end is meaningless.
    if (lines.size() != inputLabels.size()) {
        throw new IllegalStateException("feature and label files differ in length: " + lines.size() + " vs " + inputLabels.size());
    }
    Collections.shuffle(lines, new Random(0));
    Collections.shuffle(inputLabels, new Random(0));
    int rows = 100;
    int[] labels = inputLabels.stream().limit(rows).mapToInt(a -> a).toArray();
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        // Each line holds the space-separated pixel values of one image.
        String[] split = lines.get(i).split(" ");
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    int numComponents = 20;
    KMeans kMeans = new KMeans(numComponents, dataSet);
    kMeans.kmeansPlusPlusInitialize(100);
    List<Double> objectives = new ArrayList<>();
    // Toggle for dumping the clustering right after initialization.
    boolean showInitialize = true;
    if (showInitialize) {
        plotClusters(kMeans, dataSet, numComponents, "/Users/chengli/tmp/kmeans_demo/clusters/initial");
    }
    objectives.add(kMeans.objective());
    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotClusters(kMeans, dataSet, numComponents, "/Users/chengli/tmp/kmeans_demo/clusters/iter_" + iter);
    }
    int[] assignment = kMeans.getAssignments();
    System.out.println("purity=" + purity(assignment, labels, 10, numComponents));
    System.out.println("training objective changes: " + objectives);
}

/**
 * Plots the current state of a clustering as 28 x 28 images: for every
 * cluster k, its center goes to {@code <dir>/cluster_<k+1>/center.png} and
 * each member row i goes to {@code <dir>/cluster_<k+1>/pic_<i+1>.png}.
 *
 * @param kMeans        clustering whose centers and assignments are plotted
 * @param dataSet       data set the assignments index into
 * @param numComponents number of clusters to plot
 * @param dir           output directory for this snapshot, without trailing slash
 * @throws Exception if plotting fails
 */
private static void plotClusters(KMeans kMeans, DataSet dataSet, int numComponents, String dir) throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String clusterDir = dir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], 28, 28, clusterDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), 28, 28, clusterDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) BufferedImage(java.awt.image.BufferedImage) ArgMax(edu.neu.ccs.pyramid.util.ArgMax) DataSet(edu.neu.ccs.pyramid.dataset.DataSet) FileUtils(org.apache.commons.io.FileUtils) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) DataSetBuilder(edu.neu.ccs.pyramid.dataset.DataSetBuilder) ImageIO(javax.imageio.ImageIO) Accuracy(edu.neu.ccs.pyramid.eval.Accuracy) Vector(org.apache.mahout.math.Vector) DataSet(edu.neu.ccs.pyramid.dataset.DataSet) File(java.io.File)

Example 10 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class KMeansTest method extractFashionImages.

/**
 * Renders a 100-image sample of the fashion data set as 28 x 28 PNG files.
 *
 * The feature file is read line by line, shuffled with a fixed seed, and the
 * first 100 lines (comma-separated pixel values) are loaded into a data set.
 * Row i is then plotted to {@code /Users/chengli/tmp/fashion/pic_<i+1>.png}.
 *
 * @throws Exception on any I/O, parsing or plotting failure
 */
private static void extractFashionImages() throws Exception {
    final int numImages = 100;
    final int side = 28;

    File featureFile = new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt");
    List<String> records = FileUtils.readLines(featureFile);
    Collections.shuffle(records, new Random(0));

    DataSet images = DataSetBuilder.getBuilder().numDataPoints(numImages).numFeatures(side * side).build();

    // First pass: parse every sampled line completely before anything is plotted.
    for (int row = 0; row < numImages; row++) {
        String[] pixels = records.get(row).split(",");
        int col = 0;
        for (String pixel : pixels) {
            images.setFeatureValue(row, col++, Double.parseDouble(pixel));
        }
    }

    // Second pass: write one picture per row, numbered from 1.
    for (int row = 0; row < numImages; row++) {
        String target = "/Users/chengli/tmp/fashion/pic_" + (row + 1) + ".png";
        plot(images.getRow(row), side, side, target);
    }
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet) File(java.io.File)

Aggregations

DataSet (edu.neu.ccs.pyramid.dataset.DataSet)20 Vector (org.apache.mahout.math.Vector)7 File (java.io.File)6 java.util (java.util)4 Collectors (java.util.stream.Collectors)4 IntStream (java.util.stream.IntStream)4 DataSetBuilder (edu.neu.ccs.pyramid.dataset.DataSetBuilder)3 Accuracy (edu.neu.ccs.pyramid.eval.Accuracy)3 ArgMax (edu.neu.ccs.pyramid.util.ArgMax)3 BufferedImage (java.awt.image.BufferedImage)3 List (java.util.List)3 ImageIO (javax.imageio.ImageIO)3 FileUtils (org.apache.commons.io.FileUtils)3 DenseVector (org.apache.mahout.math.DenseVector)3 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)2 Ngram (edu.neu.ccs.pyramid.feature.Ngram)2 IOException (java.io.IOException)2 ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability)1 PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis)1 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)1