Search in sources :

Example 66 with Array2DRowRealMatrix

use of org.apache.commons.math3.linear.Array2DRowRealMatrix in project pyramid by cheng-li.

The method main of the class GMMDemo.

/**
 * Demo entry point: loads feature rows from a hard-coded, space-separated text file, copies a
 * shuffled sample of at most 100 rows into a dense matrix (adding {@link Math#random()} jitter
 * to every value) and computes the per-column minimum, maximum and variance.
 *
 * <p>The remainder of the original experiment (Bernoulli-mixture initialisation followed by GMM
 * training) is retained below as commented-out code.
 *
 * @param args ignored
 * @throws IllegalStateException if the feature file contains no rows
 * @throws Exception on any I/O or parsing failure
 */
public static void main(String[] args) throws Exception {
    List<String> lines = FileUtils.readLines(new File("/Users/chengli/Dropbox/Shared/CS6220DM/2_cluster_EM_mixt/HW2/mnist_features.txt"));
    if (lines.isEmpty()) {
        // Fail with a clear message instead of an IndexOutOfBoundsException from lines.get(0).
        throw new IllegalStateException("feature file contains no rows");
    }
    Collections.shuffle(lines);
    // The first (shuffled) row defines the dimensionality expected of every row.
    int dim = lines.get(0).split(" ").length;
    // Cap the sample size so that a file with fewer than 100 rows does not overrun the list.
    int rows = Math.min(100, lines.size());
    RealMatrix data = new Array2DRowRealMatrix(rows, dim);
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(" ");
        for (int j = 0; j < dim; j++) {
            // Jitter in [0, 1) is added to every parsed value.
            data.setEntry(i, j, Double.parseDouble(split[j]) + Math.random());
        }
    }
    // Per-column summary statistics; the active code below does not read them.
    int numColumns = data.getColumnDimension();
    double[] mins = new double[numColumns];
    double[] maxs = new double[numColumns];
    double[] vars = new double[numColumns];
    for (int j = 0; j < numColumns; j++) {
        RealVector column = data.getColumnVector(j);
        mins[j] = column.getMinValue();
        maxs[j] = column.getMaxValue();
        DescriptiveStatistics stats = new DescriptiveStatistics(column.toArray());
        vars[j] = stats.getVariance();
    }
// DataSet dataSet = DataSetBuilder.getBuilder()
// .numDataPoints(rows)
// .numFeatures(data.getColumnDimension())
// .build();
// for (int i=0;i<dataSet.getNumDataPoints();i++){
// for (int j=0;j<dataSet.getNumFeatures();j++){
// if (data.getEntry(i,j)>255.0/2){
// dataSet.setFeatureValue(i,j,1);
// } else {
// dataSet.setFeatureValue(i,j,0);
// }
// 
// }
// }
// 
// int numComponents = 10;
// 
// 
// BM bm = BMSelector.select(dataSet, numComponents, 100);
// System.out.println(Arrays.toString(bm.getMixtureCoefficients()));
// StringBuilder stringBuilder = new StringBuilder();
// for (int k=0;k<numComponents;k++){
// for (int d=0;d<dataSet.getNumFeatures();d++){
// stringBuilder.append(bm.getDistributions()[k][d].getP());
// if (d!=dataSet.getNumFeatures()-1){
// stringBuilder.append(",");
// }
// }
// stringBuilder.append("\n");
// }
// FileUtils.writeStringToFile(new File("/Users/chengli/tmp/gmm/bm"),stringBuilder.toString());
// BMTrainer bmTrainer = BMSelector.selectTrainer(dataSet,numComponents,50);
// 
// GMM gmm = new GMM(dim,numComponents, data);
// 
// GMMTrainer trainer = new GMMTrainer(data, gmm);
// 
// trainer.setGammas(bmTrainer.getGammas());
// trainer.mStep();
// 
// for (int i=1;i<=5;i++){
// System.out.println("iteration = "+i);
// trainer.iterate();
// double logLikelihood = IntStream.range(0,rows).parallel()
// .mapToDouble(j->gmm.logDensity(data.getRowVector(j))).sum();
// System.out.println("log likelihood = "+logLikelihood);
// Serialization.serialize(gmm, "/Users/chengli/tmp/gmm/model_iter_"+i);
// for (int k=0;k<gmm.getNumComponents();k++){
// FileUtils.writeStringToFile(new File("/Users/chengli/tmp/gmm/mean_iter_"+i+"_component_"+(k+1)),
// gmm.getGaussianDistributions()[k].getMean().toString().replace("{","")
// .replace("}","").replace(";",","));
// }
// }
// 
// FileUtils.writeStringToFile(new File("/Users/chengli/tmp/gmm/modeltext"), gmm.toString());
// GMM gmm = (GMM) Serialization.deserialize("/Users/chengli/tmp/gmm/model_3");
}
Also used : DescriptiveStatistics(org.apache.commons.math3.stat.descriptive.DescriptiveStatistics) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealVector(org.apache.commons.math3.linear.RealVector) File(java.io.File)

Example 67 with Array2DRowRealMatrix

use of org.apache.commons.math3.linear.Array2DRowRealMatrix in project pyramid by cheng-li.

The method test3 of the class GMMTrainerTest.

/**
 * Trains a 10-component GMM on at most the first 100 rows of a hard-coded, comma-separated
 * feature file (784 values per row, each with {@link Math#random()} jitter added) for five
 * iterations. After every iteration it prints the summed log density of the training rows,
 * serializes the model and writes each component mean to its own text file.
 *
 * @throws IllegalStateException if the feature file contains no rows
 * @throws Exception on any I/O, parsing or training failure
 */
private static void test3() throws Exception {
    List<String> lines = FileUtils.readLines(new File("/Users/chengli/Downloads/fashion-mnist/features.txt"));
    if (lines.isEmpty()) {
        // A zero-row matrix cannot be built; fail early with a clear message.
        throw new IllegalStateException("feature file contains no rows");
    }
    // Collections.shuffle(lines);
    int dim = 784;
    // Cap the sample size so that a file with fewer than 100 rows does not overrun the list.
    int rows = Math.min(100, lines.size());
    RealMatrix data = new Array2DRowRealMatrix(rows, dim);
    for (int i = 0; i < rows; i++) {
        String[] split = lines.get(i).split(",");
        System.out.println(Arrays.toString(split));
        for (int j = 0; j < dim; j++) {
            // Jitter in [0, 1) is added to every parsed value.
            data.setEntry(i, j, Double.parseDouble(split[j]) + Math.random());
        }
    }
    int numComponents = 10;
    GMM gmm = new GMM(dim, numComponents, data);
    GMMTrainer trainer = new GMMTrainer(data, gmm);
    for (int i = 1; i <= 5; i++) {
        System.out.println("iteration = " + i);
        trainer.iterate();
        // Sum of the per-row log densities under the current model.
        double logLikelihood = IntStream.range(0, rows)
                .parallel()
                .mapToDouble(j -> gmm.logDensity(data.getRowVector(j)))
                .sum();
        System.out.println("log likelihood = " + logLikelihood);
        Serialization.serialize(gmm, "/Users/chengli/tmp/gmm/model_iter_" + i);
        for (int k = 0; k < gmm.getNumComponents(); k++) {
            File meanFile = new File("/Users/chengli/tmp/gmm/mean_iter_" + i + "_component_" + (k + 1));
            // Strip the vector's braces and turn ';' separators into ',' before writing.
            String meanText = gmm.getGaussianDistributions()[k].getMean().toString()
                    .replace("{", "").replace("}", "").replace(";", ",");
            FileUtils.writeStringToFile(meanFile, meanText);
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) java.util(java.util) BufferedImage(java.awt.image.BufferedImage) ArgMax(edu.neu.ccs.pyramid.util.ArgMax) FileUtils(org.apache.commons.io.FileUtils) RealVector(org.apache.commons.math3.linear.RealVector) File(java.io.File) KMeans(edu.neu.ccs.pyramid.clustering.kmeans.KMeans) Serialization(edu.neu.ccs.pyramid.util.Serialization) DescriptiveStatistics(org.apache.commons.math3.stat.descriptive.DescriptiveStatistics) Entropy(edu.neu.ccs.pyramid.eval.Entropy) ImageIO(javax.imageio.ImageIO) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) RealMatrix(org.apache.commons.math3.linear.RealMatrix) BM(edu.neu.ccs.pyramid.clustering.bm.BM) BMSelector(edu.neu.ccs.pyramid.clustering.bm.BMSelector) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) File(java.io.File)

Example 68 with Array2DRowRealMatrix

use of org.apache.commons.math3.linear.Array2DRowRealMatrix in project pyramid by cheng-li.

The method fashion of the class GMMTrainerTest.

/**
 * Fits a 5-component GMM to a fixed-seed shuffled sample of at most 100 rows from a hard-coded,
 * comma-separated feature file. Each value is divided by 255 and stored in a 28*28-feature
 * {@code DataSet}, which is then copied into a dense matrix for training. After each of the five
 * iterations the summed log density is printed and every component mean is passed to
 * {@code plot}; finally each component's mean and log determinant are printed.
 *
 * <p>The output directory {@code /Users/chengli/tmp/kmeans_demo} is cleaned first. The earlier
 * k-means experiment is retained as commented-out code.
 *
 * @throws IllegalStateException if the feature file contains no rows
 * @throws Exception on any I/O, parsing or training failure
 */
private static void fashion() throws Exception {
    FileUtils.cleanDirectory(new File("/Users/chengli/tmp/kmeans_demo"));
    List<String> lines = FileUtils.readLines(new File("/Users/chengli/Dropbox/Shared/CS6220DM/data/fashion/features.txt"));
    if (lines.isEmpty()) {
        // A zero-row data set / matrix cannot be trained on; fail early with a clear message.
        throw new IllegalStateException("feature file contains no rows");
    }
    Collections.shuffle(lines, new Random(0));
    // Cap the sample size so that a file with fewer than 100 rows does not overrun the list.
    int rows = Math.min(100, lines.size());
    DataSet dataSet = DataSetBuilder.getBuilder().numDataPoints(rows).numFeatures(28 * 28).build();
    for (int i = 0; i < rows; i++) {
        String line = lines.get(i);
        String[] split = line.split(",");
        for (int j = 0; j < split.length; j++) {
            dataSet.setFeatureValue(i, j, Double.parseDouble(split[j]) / 255);
        }
    }
    int numComponents = 5;
    // KMeans kMeans = new KMeans(numComponents, dataSet);
    // //        kMeans.randomInitialize();
    // kMeans.kmeansPlusPlusInitialize(100);
    // List<Double> objectives = new ArrayList<>();
    // boolean showInitialize = true;
    // if (showInitialize){
    // int[] assignment = kMeans.getAssignments();
    // for (int k=0;k<numComponents;k++){
    // plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/initial/cluster_"+(k+1)+"/center.png");
    // //                plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"_component_"+(k+1)+"_pic_000center.png");
    // 
    // int counter = 0;
    // for (int i=0;i<assignment.length;i++){
    // if (assignment[i]==k){
    // plot(dataSet.getRow(i), 28,28,
    // "/Users/chengli/tmp/kmeans_demo/clusters/initial/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
    // counter+=1;
    // }
    // 
    // }
    // }
    // }
    // objectives.add(kMeans.objective());
    // 
    // 
    // for (int iter=1;iter<=5;iter++){
    // System.out.println("=====================================");
    // System.out.println("iteration "+iter);
    // kMeans.iterate();
    // objectives.add(kMeans.objective());
    // int[] assignment = kMeans.getAssignments();
    // for (int k=0;k<numComponents;k++){
    // plot(kMeans.getCenters()[k], 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/center.png");
    // for (int i=0;i<assignment.length;i++){
    // if (assignment[i]==k){
    // plot(dataSet.getRow(i), 28,28,
    // "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
    // }
    // }
    // }
    // 
    // System.out.println("training objective changes: "+objectives);
    // }
    // int[] assignments = kMeans.getAssignments();
    // Copy the data set into the dense matrix form that GMM / GMMTrainer take.
    int numFeatures = dataSet.getNumFeatures();
    RealMatrix data = new Array2DRowRealMatrix(rows, numFeatures);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < numFeatures; j++) {
            data.setEntry(i, j, dataSet.getRow(i).get(j));
        }
    }
    GMM gmm = new GMM(numFeatures, numComponents, data);
    GMMTrainer trainer = new GMMTrainer(data, gmm);
    // double[][] gammas = new double[assignments.length][numComponents];
    // for (int i=0;i<assignments.length;i++){
    // gammas[i][assignments[i]]=1;
    // }
    // trainer.setGammas(gammas);
    System.out.println("start training GMM");
    for (int i = 1; i <= 5; i++) {
        // trainer.mStep();
        // trainer.eStep();
        trainer.iterate();
        System.out.println("iteration " + i);
        // double[] entropies = IntStream.range(0,rows).mapToDouble(i->Entropy.entropy(gammas[i])).toArray();
        // System.out.println(Arrays.toString(entropies));
        // int max = ArgMax.argMax(entropies);
        // plot(dataSet.getRow(max), 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/max_entropy.png");
        // System.out.println(Arrays.toString(gammas[max]));
        // Sum of the per-row log densities under the current model.
        double logLikelihood = IntStream.range(0, rows).parallel().mapToDouble(j -> gmm.logDensity(data.getRowVector(j))).sum();
        System.out.println("log likelihood = " + logLikelihood);
        for (int k = 0; k < numComponents; k++) {
            plot(gmm.getGaussianDistributions()[k].getMean(), 28, 28, "/Users/chengli/tmp/kmeans_demo/clusters/iter_" + i + "/cluster_" + (k + 1) + "/center.png");
        // for (int i=0;i<assignment.length;i++){
        // if (assignment[i]==k){
        // plot(dataSet.getRow(i), 28,28,
        // "/Users/chengli/tmp/kmeans_demo/clusters/iter_"+iter+"/cluster_"+(k+1)+"/pic_"+(i+1)+".png");
        // }
        // }
        }
    }
    for (int k = 0; k < numComponents; k++) {
        System.out.println("component " + k);
        System.out.println("mean=" + gmm.getGaussianDistributions()[k].getMean());
        System.out.println("log determinant =" + gmm.getGaussianDistributions()[k].getLogDeterminant());
    }
// double[][] gammas = trainer.getGammas();
// double[] entropies = IntStream.range(0,rows).mapToDouble(i->Entropy.entropy(gammas[i])).toArray();
// System.out.println(Arrays.toString(entropies));
// int max = ArgMax.argMax(entropies);
// plot(dataSet.getRow(max), 28,28, "/Users/chengli/tmp/kmeans_demo/clusters/max_entropy.png");
// System.out.println(Arrays.toString(gammas[max]));
// System.out.println(gmm);
}
Also used : IntStream(java.util.stream.IntStream) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) java.util(java.util) BufferedImage(java.awt.image.BufferedImage) ArgMax(edu.neu.ccs.pyramid.util.ArgMax) FileUtils(org.apache.commons.io.FileUtils) RealVector(org.apache.commons.math3.linear.RealVector) File(java.io.File) KMeans(edu.neu.ccs.pyramid.clustering.kmeans.KMeans) Serialization(edu.neu.ccs.pyramid.util.Serialization) DescriptiveStatistics(org.apache.commons.math3.stat.descriptive.DescriptiveStatistics) Entropy(edu.neu.ccs.pyramid.eval.Entropy) ImageIO(javax.imageio.ImageIO) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) RealMatrix(org.apache.commons.math3.linear.RealMatrix) BM(edu.neu.ccs.pyramid.clustering.bm.BM) BMSelector(edu.neu.ccs.pyramid.clustering.bm.BMSelector) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) File(java.io.File)

Example 69 with Array2DRowRealMatrix

use of org.apache.commons.math3.linear.Array2DRowRealMatrix in project jstructure by JonStargaryen.

The method computeConfiguration of the class MultiDimensionalScaling.

/**
 * Double-centers the given proximity matrix: every entry has its row mean and its column mean
 * subtracted and the overall mean added back, i.e.
 * {@code b_ij = a_ij - a_i* - a_j* + a_**}.
 *
 * <p>Only the leading {@code numberOfDataPoints x numberOfDataPoints} entries of
 * {@code proximityMap} are read; the argument itself is not modified.
 *
 * @param proximityMap the matrix to center; must have at least {@code numberOfDataPoints} rows
 *                     and columns
 * @return a new {@code numberOfDataPoints x numberOfDataPoints} matrix holding the centered values
 */
private RealMatrix computeConfiguration(RealMatrix proximityMap) {
    final int n = numberOfDataPoints;
    RealMatrix centeringMap = new Array2DRowRealMatrix(n, n);
    double[] rowAverage = new double[n];
    double[] columnAverage = new double[n];
    double overallAverage = 0;
    // One pass over the matrix accumulates row sums, column sums and the grand total. Each
    // column still receives its entries in ascending row order, so the floating-point sums
    // match those of a dedicated column pass.
    for (int row = 0; row < n; row++) {
        double rowSum = 0;
        for (int column = 0; column < n; column++) {
            double entry = proximityMap.getEntry(row, column);
            rowSum += entry;
            columnAverage[column] += entry;
            overallAverage += entry;
        }
        rowAverage[row] = rowSum / n;
    }
    // turn the accumulated column sums into averages
    for (int column = 0; column < n; column++) {
        columnAverage[column] /= n;
    }
    // Widen before multiplying: the int product n * n overflows once n exceeds 46340.
    overallAverage /= (double) n * n;
    for (int row = 0; row < n; row++) {
        for (int column = 0; column < n; column++) {
            // b_ij = a_ij - a_i* - a_j* + a_**
            centeringMap.setEntry(row, column, proximityMap.getEntry(row, column) - rowAverage[row] - columnAverage[column] + overallAverage);
        }
    }
    return centeringMap;
}
Also used : Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix)

Example 70 with Array2DRowRealMatrix

use of org.apache.commons.math3.linear.Array2DRowRealMatrix in project gatk by broadinstitute.

The method testChromosomesOnDifferentSegments of the class CopyRatioSegmenterUnitTest.

/**
 * Generates data for three contigs that all share one hidden state and checks that the
 * segmenter still reports segments on three distinct contigs, i.e. that no segment spans a
 * contig boundary.
 */
@Test
public void testChromosomesOnDifferentSegments() {
    final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(563));
    final double[] trueLog2CopyRatios = { -2.0, 0.0, 1.7 };
    final double trueMemoryLength = 1e5;
    final double trueStandardDeviation = 0.2;
    final int chainLength = 100;
    final double quarterMemoryLength = trueMemoryLength / 4;

    // random positions on chr1, then chr2 and chr3 appended in that order
    final List<SimpleInterval> positions = randomPositions("chr1", chainLength, rng, quarterMemoryLength);
    for (final String contig : new String[] { "chr2", "chr3" }) {
        positions.addAll(randomPositions(contig, chainLength, rng, quarterMemoryLength));
    }

    // every position is drawn from the same state (index 2)
    final int trueState = 2;
    final double trueCopyRatio = trueLog2CopyRatios[trueState];
    final double[] observations = new double[positions.size()];
    for (int n = 0; n < observations.length; n++) {
        observations[n] = generateData(trueStandardDeviation, trueCopyRatio, rng);
    }

    final List<Target> targets = new ArrayList<>(positions.size());
    for (final SimpleInterval position : positions) {
        targets.add(new Target(position));
    }

    // a one-dimensional array yields a single-column matrix: one sample, one row per target
    final ReadCountCollection rcc = new ReadCountCollection(targets, Arrays.asList("SAMPLE"), new Array2DRowRealMatrix(observations));
    final List<ModeledSegment> segments = new CopyRatioSegmenter(10, rcc).getModeledSegments();

    // check that each chromosome has at least one segment
    final long distinctContigs = segments.stream().map(ModeledSegment::getContig).distinct().count();
    final int numDifferentContigsInSegments = (int) distinctContigs;
    Assert.assertEquals(numDifferentContigsInSegments, 3);
}
Also used : IntStream(java.util.stream.IntStream) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) java.util(java.util) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection) ModeledSegment(org.broadinstitute.hellbender.tools.exome.ModeledSegment) Assert(org.testng.Assert) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) RandomGeneratorFactory(org.apache.commons.math3.random.RandomGeneratorFactory) Target(org.broadinstitute.hellbender.tools.exome.Target) Test(org.testng.annotations.Test) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Target(org.broadinstitute.hellbender.tools.exome.Target) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) ModeledSegment(org.broadinstitute.hellbender.tools.exome.ModeledSegment) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Test(org.testng.annotations.Test)

Aggregations

Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)141 RealMatrix (org.apache.commons.math3.linear.RealMatrix)101 Test (org.testng.annotations.Test)60 IntStream (java.util.stream.IntStream)31 BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest)28 File (java.io.File)27 Collectors (java.util.stream.Collectors)25 ArrayList (java.util.ArrayList)24 Assert (org.testng.Assert)24 List (java.util.List)22 SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)22 Target (org.broadinstitute.hellbender.tools.exome.Target)18 java.util (java.util)15 Random (java.util.Random)14 ReadCountCollection (org.broadinstitute.hellbender.tools.exome.ReadCountCollection)14 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)14 DataProvider (org.testng.annotations.DataProvider)14 Stream (java.util.stream.Stream)13 Arrays (java.util.Arrays)12 DoubleStream (java.util.stream.DoubleStream)12