Search in sources:

Example 11 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class KMeansTest method fashiSubgroup.

/**
 * Clusters up to 100 Fashion images that carry {@code label} with k-means (k = 10, k-means++
 * initialization) and writes PNG snapshots of every cluster centre and member: once right after
 * initialization and once after each of five {@code KMeans.iterate()} rounds. The objective
 * history is printed after every iteration.
 *
 * <p>All input and output locations are hard-coded to the author's machine; the output directory
 * is cleaned before anything is written.
 *
 * @param label class label to keep; compared against the integer part of each line of
 *     {@code labels.txt}
 * @throws Exception if reading the input files, cleaning the output directory or plotting fails
 */
private static void fashiSubgroup(int label) throws Exception {
    final String outputRoot = "/Users/chengli/tmp/kmeans_demo";
    final String dataRoot = "/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion";
    FileUtils.cleanDirectory(new File(outputRoot));
    List<Integer> labels = FileUtils.readLines(new File(dataRoot + "/labels.txt"), java.nio.charset.StandardCharsets.UTF_8)
            .stream()
            .mapToInt(l -> (int) Double.parseDouble(l))
            .boxed()
            .collect(Collectors.toList());
    List<String> features = FileUtils.readLines(new File(dataRoot + "/features.txt"), java.nio.charset.StandardCharsets.UTF_8);
    // Line i of features.txt is assumed to correspond to line i of labels.txt.
    List<String> lines = IntStream.range(0, features.size())
            .filter(i -> labels.get(i) == label)
            .mapToObj(features::get)
            .collect(Collectors.toList());
    // Fewer than 100 images may carry this label; never read past the end of the list.
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(",");
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    int numComponents = 10;
    KMeans kMeans = new KMeans(numComponents, dataSet);
    // kMeans.randomInitialize();
    kMeans.kmeansPlusPlusInitialize(100);
    List<Double> objectives = new ArrayList<>();
    boolean showInitialize = true;
    if (showInitialize) {
        plotClusterSnapshot(kMeans, dataSet, numComponents, outputRoot + "/clusters/initial");
    }
    objectives.add(kMeans.objective());
    for (int iter = 1; iter <= 5; iter++) {
        System.out.println("=====================================");
        System.out.println("iteration " + iter);
        kMeans.iterate();
        objectives.add(kMeans.objective());
        plotClusterSnapshot(kMeans, dataSet, numComponents, outputRoot + "/clusters/iter_" + iter);
        System.out.println("training objective changes: " + objectives);
    }
}

/**
 * Writes the current state of {@code kMeans} as 28x28 PNGs: for each cluster k (0-based), the
 * centre goes to {@code outputDir/cluster_<k+1>/center.png}, and each data point i assigned to k
 * goes to {@code outputDir/cluster_<k+1>/pic_<i+1>.png}.
 */
private static void plotClusterSnapshot(KMeans kMeans, DataSet dataSet, int numComponents, String outputDir) throws Exception {
    int[] assignment = kMeans.getAssignments();
    for (int k = 0; k < numComponents; k++) {
        String clusterDir = outputDir + "/cluster_" + (k + 1);
        plot(kMeans.getCenters()[k], 28, 28, clusterDir + "/center.png");
        for (int i = 0; i < assignment.length; i++) {
            if (assignment[i] == k) {
                plot(dataSet.getRow(i), 28, 28, clusterDir + "/pic_" + (i + 1) + ".png");
            }
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) BufferedImage(java.awt.image.BufferedImage) ArgMax(edu.neu.ccs.pyramid.util.ArgMax) DataSet(edu.neu.ccs.pyramid.dataset.DataSet) FileUtils(org.apache.commons.io.FileUtils) Collectors(java.util.stream.Collectors) File(java.io.File) List(java.util.List) DataSetBuilder(edu.neu.ccs.pyramid.dataset.DataSetBuilder) ImageIO(javax.imageio.ImageIO) Accuracy(edu.neu.ccs.pyramid.eval.Accuracy) Vector(org.apache.mahout.math.Vector) DataSet(edu.neu.ccs.pyramid.dataset.DataSet) File(java.io.File)

Example 12 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class KMeansTest method extractMnistImages.

/**
 * Loads the MNIST feature file and shuffles its lines with a fixed seed. It then writes the first
 * (up to) 100 images as 28x28 PNGs to {@code /Users/chengli/tmp/mnist/pic_<n>.png}, with n
 * starting at 1.
 *
 * @throws Exception if the feature file cannot be read or an image cannot be written
 */
private static void extractMnistImages() throws Exception {
    List<String> lines = FileUtils.readLines(new File("/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2/mnist_features.txt"), java.nio.charset.StandardCharsets.UTF_8);
    // A fixed seed keeps the extracted sample identical across runs.
    Collections.shuffle(lines, new Random(0));
    // Never read past the end of a short file.
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        // Tolerate leading/trailing whitespace and runs of spaces between pixel values;
        // split(" ") would produce empty tokens that Double.parseDouble rejects.
        String[] split = lines.get(i).trim().split("\\s+");
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]));
        }
    }
    for (int i = 0; i < rows; i++) {
        plot(dataSet.getRow(i), 28, 28, "/Users/chengli/tmp/mnist/pic_" + (i + 1) + ".png");
    }
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet) File(java.io.File)

Example 13 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class StumpSelector method scores.

/**
 * Scores a single n-gram feature against every label row.
 *
 * <p>The n-gram is deep-copied and its index set to 0 on the copy, so that it can be loaded into
 * a throw-away one-feature {@link DataSet} without touching the caller's {@code feature}. Each
 * label row is then scored against that column with {@code score(DataSet, double[])}.
 *
 * @param index index the feature values are loaded from
 * @param labels label matrix indexed as {@code labels[label][dataPoint]}; the number of data
 *     points is taken from {@code labels[0].length}
 * @param feature the n-gram to score; only a copy is modified
 * @param idTranslator passed through to {@code FeatureLoader.loadNgramFeature}
 * @param matchScoreType passed through to {@code FeatureLoader.loadNgramFeature}
 * @param docFilter passed through to {@code FeatureLoader.loadNgramFeature}
 * @return one score per label row, in the same order as {@code labels}; an empty array if
 *     {@code labels} is empty
 * @throws IllegalStateException if {@code feature} cannot be deep-copied
 */
public static double[] scores(ESIndex index, double[][] labels, Ngram feature, IdTranslator idTranslator, FeatureLoader.MatchScoreType matchScoreType, String docFilter) {
    if (labels.length == 0) {
        // Nothing to score, and labels[0] below would not exist.
        return new double[0];
    }
    Ngram ngram;
    try {
        ngram = (Ngram) Serialization.deepCopy(feature);
    } catch (IOException | ClassNotFoundException e) {
        // Previously swallowed, which left ngram null and surfaced only as an NPE below.
        throw new IllegalStateException("failed to deep-copy ngram feature", e);
    }
    ngram.setIndex(0);
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(labels[0].length).numFeatures(1).build();
    FeatureLoader.loadNgramFeature(index, dataSet, ngram, idTranslator, matchScoreType, docFilter);
    double[] scores = new double[labels.length];
    for (int l = 0; l < scores.length; l++) {
        scores[l] = score(dataSet, labels[l]);
    }
    return scores;
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet) IOException(java.io.IOException) Ngram(edu.neu.ccs.pyramid.feature.Ngram)

Example 14 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class FeatureBinarizer method binarize.

/**
 * Loads a TREC-format data set, calls {@code DataSetUtil.binarizeFeature} on it, and saves the
 * result in TREC format. The return value of {@code binarizeFeature} is not used, so it is
 * relied on to modify {@code dataSet} in place.
 *
 * @param config must provide {@code dataSetType}: one of {@code "reg"}, {@code "clf"} or
 *     {@code "multiLabel"}, which selects the sparse loader used
 * @param inputData path of the TREC-format data set to read
 * @param outputData path the binarized data set is saved to
 * @throws IllegalArgumentException if {@code dataSetType} is missing or unsupported
 * @throws Exception if loading or saving fails
 */
private static void binarize(Config config, String inputData, String outputData) throws Exception {
    String dataType = config.getString("dataSetType");
    if (dataType == null) {
        // switch on a null String would only throw an uninformative NullPointerException.
        throw new IllegalArgumentException("dataSetType is not set");
    }
    DataSet dataSet;
    switch (dataType) {
        case "reg":
            dataSet = TRECFormat.loadRegDataSet(inputData, DataSetType.REG_SPARSE, true);
            break;
        case "clf":
            dataSet = TRECFormat.loadClfDataSet(inputData, DataSetType.CLF_SPARSE, true);
            break;
        case "multiLabel":
            dataSet = TRECFormat.loadMultiLabelClfDataSet(inputData, DataSetType.ML_CLF_SPARSE, true);
            break;
        default:
            throw new IllegalArgumentException("unknown dataSetType: " + dataType + " (expected reg, clf or multiLabel)");
    }
    DataSetUtil.binarizeFeature(dataSet);
    TRECFormat.save(dataSet, outputData);
}
Also used : DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Example 15 with DataSet

use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

the class BMSelector method selectGammas.

/**
 * Encodes each {@link MultiLabel} as a sparse 0/1 row, with a 1 in every matched label column,
 * and returns the {@code gammas} of the trainer picked by
 * {@code BMSelector.selectTrainer(dataSet, numClusters, 10)}.
 *
 * @param numClasses number of label columns in the encoded data set
 * @param multiLabels one entry per data point
 * @param numClusters forwarded to {@code selectTrainer}
 * @return the selected trainer's {@code gammas} array
 */
public static double[][] selectGammas(int numClasses, MultiLabel[] multiLabels, int numClusters) {
    final int numRows = multiLabels.length;
    DataSet indicators = DataSetBuilder.getBuilder()
            .numDataPoints(numRows)
            .numFeatures(numClasses)
            .density(Density.SPARSE_RANDOM)
            .build();
    int row = 0;
    for (MultiLabel labelSet : multiLabels) {
        for (int matched : labelSet.getMatchedLabels()) {
            indicators.setFeatureValue(row, matched, 1);
        }
        row++;
    }
    return BMSelector.selectTrainer(indicators, numClusters, 10).gammas;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSet(edu.neu.ccs.pyramid.dataset.DataSet)

Aggregations

DataSet (edu.neu.ccs.pyramid.dataset.DataSet): 20, Vector (org.apache.mahout.math.Vector): 7, File (java.io.File): 6, java.util (java.util): 4, Collectors (java.util.stream.Collectors): 4, IntStream (java.util.stream.IntStream): 4, DataSetBuilder (edu.neu.ccs.pyramid.dataset.DataSetBuilder): 3, Accuracy (edu.neu.ccs.pyramid.eval.Accuracy): 3, ArgMax (edu.neu.ccs.pyramid.util.ArgMax): 3, BufferedImage (java.awt.image.BufferedImage): 3, List (java.util.List): 3, ImageIO (javax.imageio.ImageIO): 3, FileUtils (org.apache.commons.io.FileUtils): 3, DenseVector (org.apache.mahout.math.DenseVector): 3, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 2, Ngram (edu.neu.ccs.pyramid.feature.Ngram): 2, IOException (java.io.IOException): 2, ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability): 1, PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis): 1, ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1