Usage of `org.apache.spark.mllib.linalg.Matrix` in project gatk by broadinstitute.
Class `PCATangentNormalizationUtils`, method `tangentNormalizeSpark`.
/**
 * Tangent normalizes the given counts against the reduced panel of normals (PoN) using Spark.
 * The code is a little more complex than a direct translation of the math, for optimization purposes.
 * <p>
 * Please see the notes in docs/PoN for the derivation. The projection is computed transposed:
 * </p>
 * <pre>
 *     Ahat^T = (C^T P^T) A^T
 * </pre>
 * <p>
 * Therefore C^T is the distributed {@code RowMatrix}, while P^T and A^T are small local matrices.
 * </p>
 * Notation:
 * <ul>
 * <li>pinv: P</li>
 * <li>panel: A</li>
 * <li>projection: Ahat</li>
 * <li>cases: C</li>
 * <li>betahat: C^T P^T</li>
 * <li>tangentNormalizedCounts: C - Ahat</li>
 * </ul>
 *
 * @param targetFactorNormalizedCounts source of the column names used for the output collections;
 *                                     also passed through unchanged into the result.
 * @param reducedPanelCounts the reduced panel matrix A.
 * @param reducedPanelPInvCounts the pseudo-inverse P of the reduced panel.
 * @param targetMapper converts the resulting matrices back into case {@code ReadCountCollection}s.
 * @param tangentNormalizationInputCounts the case matrix C; its columns are labeled with
 *                                        {@code targetFactorNormalizedCounts.columnNames()}.
 * @param ctx Spark context used to distribute C^T.
 * @return the tangent-normalized counts (C - Ahat), the pre-tangent-normalized counts (C),
 *         the beta-hats transposed back to (C^T P^T)^T, and {@code targetFactorNormalizedCounts}.
 */
private static PCATangentNormalizationResult tangentNormalizeSpark(final ReadCountCollection targetFactorNormalizedCounts, final RealMatrix reducedPanelCounts, final RealMatrix reducedPanelPInvCounts, final CaseToPoNTargetMapper targetMapper, final RealMatrix tangentNormalizationInputCounts, final JavaSparkContext ctx) {
    // Number of columns of C (= rows of C^T, the distributed matrix). Read it directly:
    // calling transpose() only to get a dimension would copy the entire matrix.
    final int numCaseColumns = tangentNormalizationInputCounts.getColumnDimension();

    // Make C^T a distributed matrix (RowMatrix).
    final RowMatrix caseTDistMat = SparkConverter.convertRealMatrixToSparkRowMatrix(ctx, tangentNormalizationInputCounts.transpose(), TN_NUM_SLICES_SPARK);

    // Spark local matrices (transposed). getData() is row-major, hence isTransposed = true,
    // and then transpose() yields P^T and A^T respectively.
    final Matrix pinvTLocalMat = new DenseMatrix(reducedPanelPInvCounts.getRowDimension(), reducedPanelPInvCounts.getColumnDimension(), Doubles.concat(reducedPanelPInvCounts.getData()), true).transpose();
    final Matrix panelTLocalMat = new DenseMatrix(reducedPanelCounts.getRowDimension(), reducedPanelCounts.getColumnDimension(), Doubles.concat(reducedPanelCounts.getData()), true).transpose();

    // Calculate the projection transpose in a distributed matrix, then convert it to an
    // Apache Commons matrix and transpose it back (Ahat, same shape as C).
    final RowMatrix betahatDistMat = caseTDistMat.multiply(pinvTLocalMat);
    final RowMatrix projectionTDistMat = betahatDistMat.multiply(panelTLocalMat);
    final RealMatrix projection = SparkConverter.convertSparkRowMatrixToRealMatrix(projectionTDistMat, numCaseColumns).transpose();

    // Subtract the projection from the cases.
    final RealMatrix tangentNormalizedCounts = tangentNormalizationInputCounts.subtract(projection);

    // Construct the result object and return it with the correct targets.
    final ReadCountCollection tangentNormalized = targetMapper.fromPoNtoCaseCountCollection(tangentNormalizedCounts, targetFactorNormalizedCounts.columnNames());
    final ReadCountCollection preTangentNormalized = targetMapper.fromPoNtoCaseCountCollection(tangentNormalizationInputCounts, targetFactorNormalizedCounts.columnNames());
    // C^T P^T has one row per column of C (tangentNormalizedCounts has the same dimensions as C).
    final RealMatrix tangentBetaHats = SparkConverter.convertSparkRowMatrixToRealMatrix(betahatDistMat, numCaseColumns);
    return new PCATangentNormalizationResult(tangentNormalized, preTangentNormalized, tangentBetaHats.transpose(), targetFactorNormalizedCounts);
}
Usage of `org.apache.spark.mllib.linalg.Matrix` in project gatk-protected by broadinstitute.
Class `PCATangentNormalizationUtils`, method `composeTangentNormalizationInputMatrix`.
/**
 * Builds the matrix on which tangent normalization operates.
 * <p>
 * The input is copied, and each column (count group) of the copy is transformed independently:
 * </p>
 * <ol>
 * <li>every entry is divided by its column mean, computed on the input matrix,</li>
 * <li>the ratio is transformed with {@code truncatedLog2},</li>
 * <li>and finally each column is centered by subtracting the median of its log-transformed values.</li>
 * </ol>
 *
 * @param matrix input matrix; the transformation is applied to a copy of it.
 * @return a new matrix with the transformed values; never {@code null}.
 */
private static RealMatrix composeTangentNormalizationInputMatrix(final RealMatrix matrix) {
    final double[] means = GATKProtectedMathUtils.columnMeans(matrix);
    final RealMatrix logRatios = matrix.copy();

    // Step 1: replace each entry with truncatedLog2(entry / column mean).
    logRatios.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
        @Override
        public double visit(final int row, final int column, final double value) {
            final double ratio = value / means[column];
            return truncatedLog2(ratio);
        }
    });

    // Step 2: the medians must be taken from the log-transformed values, so compute them
    // only after step 1 has finished, then center every column on its median.
    final int columnCount = matrix.getColumnDimension();
    final double[] medians = new double[columnCount];
    for (int c = 0; c < columnCount; c++) {
        medians[c] = new Median().evaluate(logRatios.getColumn(c));
    }
    logRatios.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
        @Override
        public double visit(final int row, final int column, final double value) {
            return value - medians[column];
        }
    });
    return logRatios;
}
Usage of `org.apache.spark.mllib.linalg.Matrix` in project gatk by broadinstitute.
Class `PCATangentNormalizationUtils`, method `composeTangentNormalizationInputMatrix`.
/**
 * Builds the matrix on which tangent normalization operates.
 * <p>
 * The input is copied, and each column (count group) of the copy is transformed independently:
 * </p>
 * <ol>
 * <li>every entry is divided by its column mean, computed on the input matrix,</li>
 * <li>the ratio is transformed with {@code truncatedLog2},</li>
 * <li>and finally each column is centered by subtracting the median of its log-transformed values.</li>
 * </ol>
 *
 * @param matrix input matrix; the transformation is applied to a copy of it.
 * @return a new matrix with the transformed values; never {@code null}.
 */
private static RealMatrix composeTangentNormalizationInputMatrix(final RealMatrix matrix) {
    final double[] means = GATKProtectedMathUtils.columnMeans(matrix);
    final RealMatrix logRatios = matrix.copy();

    // Step 1: replace each entry with truncatedLog2(entry / column mean).
    logRatios.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
        @Override
        public double visit(final int row, final int column, final double value) {
            final double ratio = value / means[column];
            return truncatedLog2(ratio);
        }
    });

    // Step 2: the medians must be taken from the log-transformed values, so compute them
    // only after step 1 has finished, then center every column on its median.
    final int columnCount = matrix.getColumnDimension();
    final double[] medians = new double[columnCount];
    for (int c = 0; c < columnCount; c++) {
        medians[c] = new Median().evaluate(logRatios.getColumn(c));
    }
    logRatios.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
        @Override
        public double visit(final int row, final int column, final double value) {
            return value - medians[column];
        }
    });
    return logRatios;
}
Usage of `org.apache.spark.mllib.linalg.Matrix` in project gatk-protected by broadinstitute.
Class `PCATangentNormalizationUtils`, method `tangentNormalizeSpark`.
/**
 * Tangent normalizes the given counts against the reduced panel of normals (PoN) using Spark.
 * The code is a little more complex than a direct translation of the math, for optimization purposes.
 * <p>
 * Please see the notes in docs/PoN for the derivation. The projection is computed transposed:
 * </p>
 * <pre>
 *     Ahat^T = (C^T P^T) A^T
 * </pre>
 * <p>
 * Therefore C^T is the distributed {@code RowMatrix}, while P^T and A^T are small local matrices.
 * </p>
 * Notation:
 * <ul>
 * <li>pinv: P</li>
 * <li>panel: A</li>
 * <li>projection: Ahat</li>
 * <li>cases: C</li>
 * <li>betahat: C^T P^T</li>
 * <li>tangentNormalizedCounts: C - Ahat</li>
 * </ul>
 *
 * @param targetFactorNormalizedCounts source of the column names used for the output collections;
 *                                     also passed through unchanged into the result.
 * @param reducedPanelCounts the reduced panel matrix A.
 * @param reducedPanelPInvCounts the pseudo-inverse P of the reduced panel.
 * @param targetMapper converts the resulting matrices back into case {@code ReadCountCollection}s.
 * @param tangentNormalizationInputCounts the case matrix C; its columns are labeled with
 *                                        {@code targetFactorNormalizedCounts.columnNames()}.
 * @param ctx Spark context used to distribute C^T.
 * @return the tangent-normalized counts (C - Ahat), the pre-tangent-normalized counts (C),
 *         the beta-hats transposed back to (C^T P^T)^T, and {@code targetFactorNormalizedCounts}.
 */
private static PCATangentNormalizationResult tangentNormalizeSpark(final ReadCountCollection targetFactorNormalizedCounts, final RealMatrix reducedPanelCounts, final RealMatrix reducedPanelPInvCounts, final CaseToPoNTargetMapper targetMapper, final RealMatrix tangentNormalizationInputCounts, final JavaSparkContext ctx) {
    // Number of columns of C (= rows of C^T, the distributed matrix). Read it directly:
    // calling transpose() only to get a dimension would copy the entire matrix.
    final int numCaseColumns = tangentNormalizationInputCounts.getColumnDimension();

    // Make C^T a distributed matrix (RowMatrix).
    final RowMatrix caseTDistMat = SparkConverter.convertRealMatrixToSparkRowMatrix(ctx, tangentNormalizationInputCounts.transpose(), TN_NUM_SLICES_SPARK);

    // Spark local matrices (transposed). getData() is row-major, hence isTransposed = true,
    // and then transpose() yields P^T and A^T respectively.
    final Matrix pinvTLocalMat = new DenseMatrix(reducedPanelPInvCounts.getRowDimension(), reducedPanelPInvCounts.getColumnDimension(), Doubles.concat(reducedPanelPInvCounts.getData()), true).transpose();
    final Matrix panelTLocalMat = new DenseMatrix(reducedPanelCounts.getRowDimension(), reducedPanelCounts.getColumnDimension(), Doubles.concat(reducedPanelCounts.getData()), true).transpose();

    // Calculate the projection transpose in a distributed matrix, then convert it to an
    // Apache Commons matrix and transpose it back (Ahat, same shape as C).
    final RowMatrix betahatDistMat = caseTDistMat.multiply(pinvTLocalMat);
    final RowMatrix projectionTDistMat = betahatDistMat.multiply(panelTLocalMat);
    final RealMatrix projection = SparkConverter.convertSparkRowMatrixToRealMatrix(projectionTDistMat, numCaseColumns).transpose();

    // Subtract the projection from the cases.
    final RealMatrix tangentNormalizedCounts = tangentNormalizationInputCounts.subtract(projection);

    // Construct the result object and return it with the correct targets.
    final ReadCountCollection tangentNormalized = targetMapper.fromPoNtoCaseCountCollection(tangentNormalizedCounts, targetFactorNormalizedCounts.columnNames());
    final ReadCountCollection preTangentNormalized = targetMapper.fromPoNtoCaseCountCollection(tangentNormalizationInputCounts, targetFactorNormalizedCounts.columnNames());
    // C^T P^T has one row per column of C (tangentNormalizedCounts has the same dimensions as C).
    final RealMatrix tangentBetaHats = SparkConverter.convertSparkRowMatrixToRealMatrix(betahatDistMat, numCaseColumns);
    return new PCATangentNormalizationResult(tangentNormalized, preTangentNormalized, tangentBetaHats.transpose(), targetFactorNormalizedCounts);
}
Aggregations