use of cern.colt.matrix.DoubleMatrix1D in project Gemma by PavlidisLab.
the class ComBat method plot.
/**
 * Make diagnostic plots: for each of gammaHat and deltaHat, a histogram of the observed values is written
 * together with a histogram of 10000 draws from the assumed prior distribution (normal for gammaHat, inverse
 * gamma for deltaHat). Each plot is written to its own newly created temporary PNG file, whose location is logged.
 * FIXME: As in the original ComBat, this only graphs the first batch's statistics. In principle we can (and perhaps
 * should) examine these plots for all the batches.
 *
 * @param filePrefix file prefix; it is combined with ".gammahat.histogram." or ".deltahat.histogram." to form the
 *                   prefix of each temporary file name
 * @throws IllegalArgumentException if gammaHat has not been set yet (i.e. 'run' was not called first)
 * @throws RuntimeException         wrapping any IOException raised while creating or writing a plot file
 */
public void plot(String filePrefix) {
    if (this.gammaHat == null) {
        throw new IllegalArgumentException("You must call 'run' first");
    }

    final int numHistBins = 100;
    final int numDraws = 10000;

    /*
     * View the distribution of gammaHat, which we assume will have a normal distribution
     */
    DoubleMatrix1D ghr = gammaHat.viewRow(0);
    Histogram gammaHatHist = new Histogram("GammaHat", numHistBins, ghr);
    XYSeries ghplot = gammaHatHist.plot();
    Normal rn = new Normal(this.gammaBar.get(0), Math.sqrt(this.t2.get(0)), new MersenneTwister());
    // The theoretical histogram is binned over the observed range so that the two series are comparable.
    Histogram ghtheoryT = new Histogram("Gamma", numHistBins, gammaHatHist.min(), gammaHatHist.max());
    for (int i = 0; i < numDraws; i++) {
        ghtheoryT.fill(rn.nextDouble());
    }
    XYSeries ghtheory = ghtheoryT.plot();
    this.writeHistogramPlot(filePrefix + ".gammahat.histogram.", ghplot, ghtheory);

    /*
     * View the distribution of deltaHat, which we assume has an inverse gamma distribution
     */
    DoubleMatrix1D dhr = deltaHat.viewRow(0);
    Histogram deltaHatHist = new Histogram("DeltaHat", numHistBins, dhr);
    XYSeries dhplot = deltaHatHist.plot();
    Gamma g = new Gamma(aPrior.get(0), bPrior.get(0), new MersenneTwister());
    Histogram deltaHatT = new Histogram("Delta", numHistBins, deltaHatHist.min(), deltaHatHist.max());
    for (int i = 0; i < numDraws; i++) {
        // Reciprocal of a gamma draw gives an inverse-gamma draw.
        deltaHatT.fill(1.0 / g.nextDouble());
    }
    XYSeries dhtheory = deltaHatT.plot();
    this.writeHistogramPlot(filePrefix + ".deltahat.histogram.", dhplot, dhtheory);
}

/**
 * Create a temporary ".png" file, log its location and write one observed-vs-theoretical plot into it. The output
 * stream is closed before this method returns, so no file handle outlives its own plot.
 *
 * @param tempFilePrefix prefix of the temporary file name
 * @param observed       series built from the observed values
 * @param theoretical    series built from draws of the assumed distribution
 * @throws RuntimeException wrapping any IOException raised while creating or writing the file
 */
private void writeHistogramPlot(String tempFilePrefix, XYSeries observed, XYSeries theoretical) {
    try {
        File tmpfile = File.createTempFile(tempFilePrefix, ".png");
        ComBat.log.info(tmpfile);
        try (OutputStream os = new FileOutputStream(tmpfile)) {
            this.writePlot(os, observed, theoretical);
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
use of cern.colt.matrix.DoubleMatrix1D in project Gemma by PavlidisLab.
the class ComBat method rawAdjust.
/**
 * Apply the batch adjustment to the data, one batch at a time: the batch's location term
 * (transpose of batchDesign * gammastar) is subtracted from the batch's data, and the result is divided row-wise
 * by the square root of the batch's row of deltastar. The adjusted values are then written back to the column
 * positions recorded in originalLocationsInMatrix, so the result has the same layout as sdata.
 * <p>
 * Neither gammastar nor deltastar is modified: the square root is taken on a copy of the deltastar row, because
 * viewRow returns a view backed by the matrix and an in-place assign would overwrite the caller's estimates.
 *
 * @param sdata     data to adjust; the result has the same number of rows and columns
 * @param gammastar additive batch-effect estimates, multiplied by each batch's design matrix
 * @param deltastar multiplicative batch-effect estimates; row k is used for the k-th batch in the iteration order
 *                  of batches.keySet() (presumably one row per batch, one column per data row -- confirm)
 * @return a new matrix holding the adjusted data
 */
private DoubleMatrix2D rawAdjust(DoubleMatrix2D sdata, DoubleMatrix2D gammastar, DoubleMatrix2D deltastar) {
    DoubleMatrix2D adjustedData = new DenseDoubleMatrix2D(sdata.rows(), sdata.columns());
    int batchNum = 0;
    for (String batchId : batches.keySet()) {
        DoubleMatrix2D batchData = this.getBatchData(sdata, batchId);
        DoubleMatrix2D Xbb = this.getBatchDesign(batchId);

        // Subtract the additive batch effect from a copy of this batch's data.
        DoubleMatrix2D adjustedBatch = batchData.copy()
                .assign(solver.transpose(solver.mult(Xbb, gammastar)), Functions.minus);

        // Work on a copy: assigning through the view would replace the caller's deltastar with its square roots.
        DoubleMatrix1D deltaStarSqrt = deltastar.viewRow(batchNum).copy().assign(Functions.sqrt);

        // Outer product with a vector of ones repeats the per-row divisor across every column of the batch.
        DoubleMatrix1D ones = new DenseDoubleMatrix1D(batchData.columns());
        ones.assign(1.0);
        DoubleMatrix2D divisor = solver.multOuter(deltaStarSqrt, ones, null);
        adjustedBatch.assign(divisor, Functions.div);

        /*
         * Now we have to put the data back in the right order -- the batches are all together.
         */
        Map<C, Integer> locations = originalLocationsInMatrix.get(batchId);
        for (int row = 0; row < adjustedBatch.rows(); row++) {
            int j = 0;
            for (Integer index : locations.values()) {
                adjustedData.set(row, index, adjustedBatch.get(row, j));
                j++;
            }
        }
        batchNum++;
    }
    return adjustedData;
}
use of cern.colt.matrix.DoubleMatrix1D in project Gemma by PavlidisLab.
the class ComBat method getBatchDesign.
/**
 * Extract the part of the design matrix x that belongs to one batch: one row per sample of the batch (in the
 * iteration order of the batch's sample collection) and one column for each of the first batches.keySet().size()
 * columns of x.
 *
 * @param batchId identifier of the batch, used as the key into batches
 * @return a new matrix of size (samples in batch) x (number of batches)
 */
private DoubleMatrix2D getBatchDesign(String batchId) {
    Collection<C> sampleNames = batches.get(batchId);
    int numBatchColumns = batches.keySet().size();
    DoubleMatrix2D result = new DenseDoubleMatrix2D(sampleNames.size(), numBatchColumns);
    int i = 0;
    for (C sname : sampleNames) {
        // Look up the sample's design row once, rather than once per batch column.
        DoubleMatrix1D rowInBatch = x.viewRow(data.getColIndexByName(sname));
        for (int j = 0; j < numBatchColumns; j++) {
            result.set(i, j, rowInBatch.get(j));
        }
        i++;
    }
    return result;
}
use of cern.colt.matrix.DoubleMatrix1D in project Gemma by PavlidisLab.
the class ComBat method nonParametricFit.
/**
 * Non-parametric estimation of the adjusted batch parameters. For every row i of the matrix, the estimates are
 * likelihood-weighted averages of gHat and dHat over all other rows j, where the weight of row j is the Gaussian
 * likelihood of row i's non-missing values given mean gHat[j] and variance dHat[j].
 * <p>
 * Rows j whose likelihood evaluates to NaN are skipped. If the weights of a row sum to zero, the division below
 * yields NaN for that row.
 *
 * @param matrix data for one batch; one estimate pair is produced per row
 * @param gHat   per-row additive estimates, indexed like the rows of matrix
 * @param dHat   per-row multiplicative estimates, indexed like the rows of matrix
 * @return a two-element array: the adjusted gHat values followed by the adjusted dHat values
 */
private DoubleMatrix1D[] nonParametricFit(DoubleMatrix2D matrix, DoubleMatrix1D gHat, DoubleMatrix1D dHat) {
    int numRows = matrix.rows();
    DoubleMatrix1D gstar = new DenseDoubleMatrix1D(numRows);
    DoubleMatrix1D dstar = new DenseDoubleMatrix1D(numRows);
    double twopi = 2.0 * Math.PI;
    StopWatch timer = new StopWatch();
    timer.start();
    /*
     * Vectorized schmectorized. In R you end up looping over the data many times. It's slow here too... but not too
     * horrible. 1000 rows of a 10k probe data set with 10 samples takes about 7.5 seconds on my laptop -- but this
     * has to be done for each batch. It's O( M*N^2 )
     */
    for (int i = 0; i < numRows; i++) {
        // Named rowValues so that it does not shadow the design-matrix field x.
        double[] rowValues = MatrixUtil.removeMissing(matrix.viewRow(i)).toArray();
        double no2 = rowValues.length / 2.0;
        double sumLH = 0.0;
        double sumgLH = 0.0;
        double sumdLH = 0.0;
        for (int j = 0; j < numRows; j++) {
            if (j == i) {
                continue;
            }
            double g = gHat.getQuick(j);
            double d = dHat.getQuick(j);
            // Sum of squared differences between the current data row and gHat[j]; a plain loop is faster than
            // the equivalent colt aggregate call.
            double sum2 = 0.0;
            for (double value : rowValues) {
                sum2 += Math.pow(value - g, 2);
            }
            double LH = (1.0 / Math.pow(twopi * d, no2)) * Math.exp(-sum2 / (2 * d));
            if (Double.isNaN(LH)) {
                continue;
            }
            sumLH += LH;
            sumgLH += g * LH;
            sumdLH += d * LH;
        }
        gstar.set(i, sumgLH / sumLH);
        dstar.set(i, sumdLH / sumLH);

        // Report the number of completed rows (i is zero-based, so i + 1 rows are done at this point).
        int rowsDone = i + 1;
        if (rowsDone % 1000 == 0) {
            ComBat.log.info(rowsDone + String.format(" rows done, %.1fs elapsed", timer.getTime() / 1000.00));
        }
    }
    return new DoubleMatrix1D[] { gstar, dstar };
}
use of cern.colt.matrix.DoubleMatrix1D in project Gemma by PavlidisLab.
the class ComBat method standardize.
/**
* Special standardization: partial regression of covariates.
* <p>
* Fits a least-squares model, then returns the data with a per-row, per-sample mean subtracted and divided by a
* per-row pooled standard deviation. The subtracted mean is the batch-size-weighted average of the first numBatches
* coefficient rows plus the fitted contribution of the remaining design columns.
* <p>
* Side effects: the fields varpooled and standMean are overwritten.
* <p>
* NOTE(review): the fit uses the parameters A and b, but the residuals and the returned values are computed from
* the fields x and y. Presumably callers pass those same matrices -- confirm.
*
* @param b passed as the second argument to LeastSquaresFit (presumably the data matrix -- confirm)
* @param A passed as the first argument to LeastSquaresFit (presumably the design matrix -- confirm)
* @return a new matrix with the dimensions of y, holding (y - standMean) / sqrt(varpooled) element-wise
*/
DoubleMatrix2D standardize(DoubleMatrix2D b, DoubleMatrix2D A) {
DoubleMatrix2D beta = new LeastSquaresFit(A, b).getCoefficients();
// Leftover expected values from a test, kept for reference:
// assertEquals( 3.7805, beta.get( 0, 0 ), 0.001 );
// assertEquals( 0.0541, beta.get( 2, 18 ), 0.001 );
int batchIndex = 0;
// 1 x numBatches row vector holding each batch's share of the samples (batch size / numSamples).
DoubleMatrix2D bba = new DenseDoubleMatrix2D(1, numBatches);
for (String batchId : batches.keySet()) {
bba.set(0, batchIndex++, (double) batches.get(batchId).size() / numSamples);
}
/*
* Weight the first numBatches rows of the coefficients (presumably the batch coefficients, since those columns
* of the design are the ones zeroed as "batch factors" below) by the batch sizes.
*/
DoubleMatrix2D grandMeanM = solver.mult(bba, beta.viewPart(0, 0, numBatches, beta.columns()));
if (hasMissing) {
// Residuals of the fit: y - transpose(x * beta).
varpooled = y.copy().assign(solver.transpose(solver.mult(x, beta)), Functions.minus);
// One variance per row, computed with the missing-value-aware helpers.
DoubleMatrix2D var = new DenseDoubleMatrix2D(varpooled.rows(), 1);
for (int i = 0; i < varpooled.rows(); i++) {
DoubleMatrix1D row = varpooled.viewRow(i);
double m = DescriptiveWithMissing.mean(new DoubleArrayList(row.toArray()));
double v = DescriptiveWithMissing.sampleVariance(new DoubleArrayList(row.toArray()), m);
var.set(i, 0, v);
}
varpooled = var;
} else {
// Squared residuals of the fit ...
varpooled = y.copy().assign(solver.transpose(solver.mult(x, beta)), Functions.minus).assign(Functions.pow(2));
// ... averaged across samples by multiplying with a column vector of 1/numSamples.
DoubleMatrix2D scale = new DenseDoubleMatrix2D(numSamples, 1);
scale.assign(1.0 / numSamples);
varpooled = solver.mult(varpooled, scale);
}
// Column vector of ones; its transpose is used to repeat a per-row value across all numSamples columns.
DoubleMatrix2D size = new DenseDoubleMatrix2D(numSamples, 1);
size.assign(1.0);
/*
* The coefficients repeated for each sample.
*/
standMean = solver.mult(solver.transpose(grandMeanM), solver.transpose(size));
/*
* Erase the batch factors from a copy of the design matrix
*/
DoubleMatrix2D tmpX = x.copy();
for (batchIndex = 0; batchIndex < numBatches; batchIndex++) {
for (int j = 0; j < x.rows(); j++) {
tmpX.set(j, batchIndex, 0.0);
}
}
/*
* row means, adjusted "per group", and ignoring batch effects.
*/
standMean = standMean.assign(solver.transpose(solver.mult(tmpX, beta)), Functions.plus);
// Per-row standard deviation (square root of the pooled variance) repeated across all samples.
DoubleMatrix2D varsq = solver.mult(varpooled.copy().assign(Functions.sqrt), solver.transpose(size));
/*
* Subtract the mean and divide by the standard deviations.
*/
DoubleMatrix2D meansubtracted = y.copy().assign(standMean, Functions.minus);
return meansubtracted.assign(varsq, Functions.div);
}
Aggregations