Usage of org.apache.commons.math3.util.Pair in the project gatk by broadinstitute:
class TargetCoverageSexGenotypeCalculator, method processReadCountsAndTargets.
/**
 * Processes raw read counts and targets:
 * <ul>
 *   <li>If more than one sample is present in the collection, filters out fully uncovered targets
 *       from the read counts and removes the uncovered targets from the target list.</li>
 *   <li>Otherwise, performs no filtering and warns the user.</li>
 * </ul>
 *
 * @param rawReadCounts raw read count collection
 * @param targetList user-provided target list
 * @return pair of processed read counts and the corresponding (possibly filtered) target list
 */
private ImmutablePair<ReadCountCollection, List<Target>> processReadCountsAndTargets(
        @Nonnull final ReadCountCollection rawReadCounts,
        @Nonnull final List<Target> targetList) {
    final ReadCountCollection finalReadCounts;
    final List<Target> finalTargetList;
    if (rawReadCounts.columnNames().size() > 1) {
        /* multiple samples: totally uncovered targets can be detected and removed */
        finalReadCounts = ReadCountCollectionUtils.removeTotallyUncoveredTargets(rawReadCounts, logger);
        final Set<Target> targetSetFromProcessedReadCounts = new HashSet<>(finalReadCounts.targets());
        /* keep only targets that survived the filtering, preserving their original order */
        finalTargetList = targetList.stream()
                .filter(targetSetFromProcessedReadCounts::contains)
                .collect(Collectors.toList());
    } else {
        /* single sample: uncovered targets cannot be reliably identified; warn and pass data through.
         * NOTE(review): the (int) cast truncates fractional counts toward zero, so any count in
         * [0, 1) is reported as uncovered — confirm this is the intended definition */
        final long numUncoveredTargets = rawReadCounts.records().stream()
                .filter(rec -> (int) rec.getDouble(0) == 0)
                .count();
        final long numAllTargets = rawReadCounts.targets().size();
        logger.info("Since only one sample is given for genotyping, the user is responsible for asserting" +
                " the aptitude of targets. Fully uncovered (irrelevant) targets can not be automatically" +
                " identified (total targets: " + numAllTargets + ", uncovered targets: " + numUncoveredTargets + ")");
        finalReadCounts = rawReadCounts;
        finalTargetList = targetList;
    }
    return ImmutablePair.of(finalReadCounts, finalTargetList);
}
Usage of org.apache.commons.math3.util.Pair in the project gatk-protected by broadinstitute:
class CoverageModelParameters, method adaptModelToReadCountCollection.
/**
 * This method "adapts" a model to a read count collection in the following sense:
 *
 * - removes targets that are not included in the model from the read count collection
 * - removes targets that are not included in the read count collection from the model
 * - rearranges model targets in the same order as read count collection targets
 *
 * The modifications are not done in-place and the original input parameters remain intact.
 *
 * @param model a model
 * @param readCounts a read count collection
 * @param logger a logger for progress and diagnostic messages
 * @return a pair of adapted model and read count collection
 */
public static ImmutablePair<CoverageModelParameters, ReadCountCollection> adaptModelToReadCountCollection(
        @Nonnull final CoverageModelParameters model,
        @Nonnull final ReadCountCollection readCounts,
        @Nonnull final Logger logger) {
    logger.info("Adapting model to read counts...");
    Utils.nonNull(model, "The model parameters must be non-null");
    Utils.nonNull(readCounts, "The read count collection must be non-null");
    Utils.nonNull(logger, "The logger must be non-null");
    final List<Target> modelTargetList = model.getTargetList();
    final List<Target> readCountsTargetList = readCounts.targets();
    /* targets present in both the model and the read count collection */
    final Set<Target> mutualTargetSet = Sets.intersection(
            new HashSet<>(modelTargetList), new HashSet<>(readCountsTargetList));
    /* order mutual targets by their order in the read count collection */
    final List<Target> mutualTargetList = readCountsTargetList.stream()
            .filter(mutualTargetSet::contains)
            .collect(Collectors.toList());
    logger.info("Number of mutual targets: " + mutualTargetList.size());
    Utils.validateArg(!mutualTargetList.isEmpty(), "The intersection between model targets and targets from read count" +
            " collection is empty. Please check that the model is compatible with the given read count" +
            " collection.");
    if (modelTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from the model: " +
                Sets.difference(new HashSet<>(modelTargetList), mutualTargetSet).stream()
                        .map(Target::getName)
                        .collect(Collectors.joining(", ", "[", "]")));
    }
    if (readCountsTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from read counts: " +
                Sets.difference(new HashSet<>(readCountsTargetList), mutualTargetSet).stream()
                        .map(Target::getName)
                        .collect(Collectors.joining(", ", "[", "]")));
    }
    /* the targets in subsetReadCounts follow the original order of targets in readCounts */
    final ReadCountCollection subsetReadCounts = readCounts.subsetTargets(mutualTargetSet);
    /* fetch original model parameters */
    final INDArray originalModelTargetMeanBias = model.getTargetMeanLogBias();
    final INDArray originalModelTargetUnexplainedVariance = model.getTargetUnexplainedVariance();
    final INDArray originalModelMeanBiasCovariates = model.getMeanBiasCovariates();
    /* map each model target to its index in the original model arrays */
    final Map<Target, Integer> modelTargetsToIndexMap = IntStream.range(0, modelTargetList.size())
            .boxed()
            .collect(Collectors.toMap(modelTargetList::get, ti -> ti));
    final int[] newTargetIndicesInOriginalModel = mutualTargetList.stream()
            .mapToInt(modelTargetsToIndexMap::get)
            .toArray();
    /* re-arrange target mean log bias and target-specific unexplained variance */
    final INDArray newModelTargetMeanBias = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    final INDArray newModelTargetUnexplainedVariance = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
        newModelTargetMeanBias.put(0, ti,
                originalModelTargetMeanBias.getDouble(0, newTargetIndicesInOriginalModel[ti]));
        newModelTargetUnexplainedVariance.put(0, ti,
                originalModelTargetUnexplainedVariance.getDouble(0, newTargetIndicesInOriginalModel[ti]));
    });
    /* if the model has bias covariates and/or ARD, re-arrange the bias covariate rows as well */
    final INDArray newModelMeanBiasCovariates;
    if (model.isBiasCovariatesEnabled()) {
        newModelMeanBiasCovariates = Nd4j.create(new int[] { mutualTargetList.size(), model.getNumLatents() });
        IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
            newModelMeanBiasCovariates.get(NDArrayIndex.point(ti), NDArrayIndex.all())
                    .assign(originalModelMeanBiasCovariates.get(
                            NDArrayIndex.point(newTargetIndicesInOriginalModel[ti]), NDArrayIndex.all()));
        });
    } else {
        newModelMeanBiasCovariates = null;
    }
    return ImmutablePair.of(
            new CoverageModelParameters(mutualTargetList, newModelTargetMeanBias,
                    newModelTargetUnexplainedVariance, newModelMeanBiasCovariates,
                    model.getBiasCovariateARDCoefficients()),
            subsetReadCounts);
}
Usage of org.apache.commons.math3.util.Pair in the project gatk by broadinstitute:
class CoverageModelParameters, method adaptModelToReadCountCollection.
/**
 * This method "adapts" a model to a read count collection in the following sense:
 *
 * - removes targets that are not included in the model from the read count collection
 * - removes targets that are not included in the read count collection from the model
 * - rearranges model targets in the same order as read count collection targets
 *
 * The modifications are not done in-place and the original input parameters remain intact.
 *
 * @param model a model
 * @param readCounts a read count collection
 * @param logger a logger for progress and diagnostic messages
 * @return a pair of adapted model and read count collection
 */
public static ImmutablePair<CoverageModelParameters, ReadCountCollection> adaptModelToReadCountCollection(
        @Nonnull final CoverageModelParameters model,
        @Nonnull final ReadCountCollection readCounts,
        @Nonnull final Logger logger) {
    logger.info("Adapting model to read counts...");
    Utils.nonNull(model, "The model parameters must be non-null");
    Utils.nonNull(readCounts, "The read count collection must be non-null");
    Utils.nonNull(logger, "The logger must be non-null");
    final List<Target> modelTargetList = model.getTargetList();
    final List<Target> readCountsTargetList = readCounts.targets();
    /* targets present in both the model and the read count collection */
    final Set<Target> mutualTargetSet = Sets.intersection(
            new HashSet<>(modelTargetList), new HashSet<>(readCountsTargetList));
    /* order mutual targets by their order in the read count collection */
    final List<Target> mutualTargetList = readCountsTargetList.stream()
            .filter(mutualTargetSet::contains)
            .collect(Collectors.toList());
    logger.info("Number of mutual targets: " + mutualTargetList.size());
    Utils.validateArg(!mutualTargetList.isEmpty(), "The intersection between model targets and targets from read count" +
            " collection is empty. Please check that the model is compatible with the given read count" +
            " collection.");
    if (modelTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from the model: " +
                Sets.difference(new HashSet<>(modelTargetList), mutualTargetSet).stream()
                        .map(Target::getName)
                        .collect(Collectors.joining(", ", "[", "]")));
    }
    if (readCountsTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from read counts: " +
                Sets.difference(new HashSet<>(readCountsTargetList), mutualTargetSet).stream()
                        .map(Target::getName)
                        .collect(Collectors.joining(", ", "[", "]")));
    }
    /* the targets in subsetReadCounts follow the original order of targets in readCounts */
    final ReadCountCollection subsetReadCounts = readCounts.subsetTargets(mutualTargetSet);
    /* fetch original model parameters */
    final INDArray originalModelTargetMeanBias = model.getTargetMeanLogBias();
    final INDArray originalModelTargetUnexplainedVariance = model.getTargetUnexplainedVariance();
    final INDArray originalModelMeanBiasCovariates = model.getMeanBiasCovariates();
    /* map each model target to its index in the original model arrays */
    final Map<Target, Integer> modelTargetsToIndexMap = IntStream.range(0, modelTargetList.size())
            .boxed()
            .collect(Collectors.toMap(modelTargetList::get, ti -> ti));
    final int[] newTargetIndicesInOriginalModel = mutualTargetList.stream()
            .mapToInt(modelTargetsToIndexMap::get)
            .toArray();
    /* re-arrange target mean log bias and target-specific unexplained variance */
    final INDArray newModelTargetMeanBias = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    final INDArray newModelTargetUnexplainedVariance = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
        newModelTargetMeanBias.put(0, ti,
                originalModelTargetMeanBias.getDouble(0, newTargetIndicesInOriginalModel[ti]));
        newModelTargetUnexplainedVariance.put(0, ti,
                originalModelTargetUnexplainedVariance.getDouble(0, newTargetIndicesInOriginalModel[ti]));
    });
    /* if the model has bias covariates and/or ARD, re-arrange the bias covariate rows as well */
    final INDArray newModelMeanBiasCovariates;
    if (model.isBiasCovariatesEnabled()) {
        newModelMeanBiasCovariates = Nd4j.create(new int[] { mutualTargetList.size(), model.getNumLatents() });
        IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
            newModelMeanBiasCovariates.get(NDArrayIndex.point(ti), NDArrayIndex.all())
                    .assign(originalModelMeanBiasCovariates.get(
                            NDArrayIndex.point(newTargetIndicesInOriginalModel[ti]), NDArrayIndex.all()));
        });
    } else {
        newModelMeanBiasCovariates = null;
    }
    return ImmutablePair.of(
            new CoverageModelParameters(mutualTargetList, newModelTargetMeanBias,
                    newModelTargetUnexplainedVariance, newModelMeanBiasCovariates,
                    model.getBiasCovariateARDCoefficients()),
            subsetReadCounts);
}
Usage of org.apache.commons.math3.util.Pair in the project gatk by broadinstitute:
class HDF5PCACoveragePoNCreationUtilsUnitTest, method testSubsetTargetToUsableOnes.
@Test(dataProvider = "readCountAndPercentileData")
public void testSubsetTargetToUsableOnes(final ReadCountCollection readCount, final double percentile) {
    // Independently compute per-target medians and the percentile cutoff that the
    // code under test is expected to apply.
    final Median medianCalculator = new Median();
    final RealMatrix countMatrix = readCount.counts();
    final double[] perTargetMedians = IntStream.range(0, countMatrix.getRowDimension())
            .mapToDouble(row -> medianCalculator.evaluate(countMatrix.getRow(row)))
            .toArray();
    final double cutoff = new Percentile(percentile).evaluate(perTargetMedians);
    final Boolean[] expectedKeepMask = DoubleStream.of(perTargetMedians)
            .mapToObj(m -> m >= cutoff)
            .toArray(Boolean[]::new);
    final int expectedKeptCount = (int) Stream.of(expectedKeepMask).filter(keep -> keep).count();
    // Run the code under test.
    final Pair<ReadCountCollection, double[]> result =
            HDF5PCACoveragePoNCreationUtils.subsetReadCountsToUsableTargets(readCount, percentile, NULL_LOGGER);
    Assert.assertEquals(result.getLeft().targets().size(), expectedKeptCount);
    Assert.assertEquals(result.getRight().length, expectedKeptCount);
    // Kept targets must appear in the result in their original relative order with matching
    // counts and medians; dropped targets must be absent.
    int expectedNextIndex = 0;
    for (int i = 0; i < expectedKeepMask.length; i++) {
        final int index = result.getLeft().targets().indexOf(readCount.targets().get(i));
        if (expectedKeepMask[i]) {
            Assert.assertEquals(index, expectedNextIndex++);
            Assert.assertEquals(countMatrix.getRow(i), result.getLeft().counts().getRow(index));
            Assert.assertEquals(result.getRight()[index], perTargetMedians[i]);
        } else {
            Assert.assertEquals(index, -1);
        }
    }
}
Aggregations