Search in sources :

Example 41 with BiFunction

Use of java.util.function.BiFunction in project gatk-protected by broadinstitute.

In the class CoverageModelEMWorkspace, method getCopyRatioSegmentsSpark.

/**
 * Computes the copy ratio segments of every sample from the compute blocks (Spark implementation).
 *
 * <p>The emission data of each sample is first turned into {@link CopyRatioHMMResults} by the
 * copy ratio expectations calculator; each result is then segmented by an
 * {@link HMMSegmentProcessor}. The per-sample segment lists are collected and ordered by
 * ascending sample index.</p>
 *
 * @return one list of {@link HiddenStateSegmentRecord}s per sample, ordered by sample index
 */
private List<List<HiddenStateSegmentRecord<STATE, Target>>> getCopyRatioSegmentsSpark() {
    /* copy instance state into final locals; the lambdas below capture these locals instead of the fields */
    final List<Target> targets = new ArrayList<>(this.processedTargetList);
    final List<SexGenotypeData> sexGenotypes = new ArrayList<>(this.processedSampleSexGenotypeData);
    final List<String> sampleNames = new ArrayList<>(this.processedSampleNameList);
    final INDArray readDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final BiFunction<SexGenotypeData, Target, STATE> refStateFactory = this.referenceStateFactory;

    final List<Tuple2<Integer, List<HiddenStateSegmentRecord<STATE, Target>>>> segmentsBySample = fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(emissionIter -> {
        /* stage 1: emission data -> HMM results, keyed by sample index */
        final List<Tuple2<Integer, CopyRatioHMMResults<CoverageModelCopyRatioEmissionData, STATE>>> hmmResults = new ArrayList<>();
        while (emissionIter.hasNext()) {
            final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> entry = emissionIter.next();
            final int si = entry._1;
            final CopyRatioCallingMetadata metadata = CopyRatioCallingMetadata.builder()
                    .sampleName(sampleNames.get(si))
                    .sampleSexGenotypeData(sexGenotypes.get(si))
                    .sampleCoverageDepth(readDepths.getDouble(si))
                    .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                    .build();
            hmmResults.add(new Tuple2<>(si, calculator.getCopyRatioHMMResults(metadata, targets, entry._2)));
        }
        return hmmResults.iterator();
    }, true).mapPartitionsToPair(resultIter -> {
        /* stage 2: HMM results -> segment records, still keyed by sample index */
        final List<Tuple2<Integer, List<HiddenStateSegmentRecord<STATE, Target>>>> segments = new ArrayList<>();
        while (resultIter.hasNext()) {
            final Tuple2<Integer, CopyRatioHMMResults<CoverageModelCopyRatioEmissionData, STATE>> entry = resultIter.next();
            final int si = entry._1;
            final CopyRatioHMMResults<CoverageModelCopyRatioEmissionData, STATE> hmm = entry._2;
            /* the processor takes parallel lists; each list here carries the single sample at hand */
            final HMMSegmentProcessor<CoverageModelCopyRatioEmissionData, STATE, Target> segmenter = new HMMSegmentProcessor<>(
                    Collections.singletonList(hmm.getMetaData().getSampleName()),
                    Collections.singletonList(hmm.getMetaData().getSampleSexGenotypeData()),
                    refStateFactory,
                    Collections.singletonList(new HashedListTargetCollection<>(targets)),
                    Collections.singletonList(hmm.getForwardBackwardResult()),
                    Collections.singletonList(hmm.getViterbiResult()));
            segments.add(new Tuple2<>(si, segmenter.getSegmentsAsList()));
        }
        return segments.iterator();
    }).collect();

    /* order by sample index, then drop the index */
    return segmentsBySample.stream()
            .sorted(Comparator.comparingInt(pair -> pair._1))
            .map(pair -> pair._2)
            .collect(Collectors.toList());
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor)

Example 42 with BiFunction

Use of java.util.function.BiFunction in project gatk-protected by broadinstitute.

In the class ReadCountRecordUnitTest, method testAppendCountsTo.

/**
 * Checks that {@code appendCountsTo} writes the record's counts into a {@link DataLine} starting
 * at the line's current position, and that the columns before and after the counts keep their
 * filler value.
 *
 * @param testName    name of the test case; the case named {@code "long[]"} is expected to yield
 *                    rounded counts
 * @param constructor factory that builds the record under test from a target and its counts
 * @param size        number of counts to generate
 */
@Test(dataProvider = "testData", dependsOnMethods = "testCreation")
public void testAppendCountsTo(final String testName, final BiFunction<Target, double[], ? extends ReadCountRecord> constructor, final int size) {
    // Number of filler columns placed on each side of the count columns.
    final int padding = 10;
    // Value every column is pre-filled with, so that untouched columns can be recognised afterwards.
    final String filler = "-11";

    final double[] counts = generateCounts(size);
    final boolean round = testName.equals("long[]");
    final ReadCountRecord record = constructor.apply(TEST_TARGET, counts);
    final List<String> columnNames = Stream.concat(Stream.concat(IntStream.range(0, padding).mapToObj(i -> "pre-padding_" + i), IntStream.range(0, counts.length).mapToObj(i -> "column_" + i)), IntStream.range(0, padding).mapToObj(i -> "post-padding_" + i)).collect(Collectors.toList());
    final TableColumnCollection columns = new TableColumnCollection(columnNames);
    final DataLine dataLine = new DataLine(columns, RuntimeException::new);
    // pre-padding + counts + post-padding
    final int totalColumns = columnNames.size();
    for (int i = 0; i < totalColumns; i++) {
        dataLine.append(filler);
    }
    // Position the line right after the pre-padding so the counts land in the "column_*" slots.
    dataLine.seek(padding);
    record.appendCountsTo(dataLine);
    // Check the copied values.
    if (!round) {
        for (int i = 0; i < counts.length; i++) {
            Assert.assertEquals(dataLine.getDouble(padding + i), counts[i], 0.0);
        }
    } else {
        for (int i = 0; i < counts.length; i++) {
            Assert.assertEquals(dataLine.getDouble(padding + i), Math.round(counts[i]), 0.00001);
        }
    }
    // Check that the padding remains intact:
    for (int i = 0; i < padding; i++) {
        Assert.assertEquals(dataLine.get(i), filler);
    }
    for (int i = counts.length + padding; i < totalColumns; i++) {
        Assert.assertEquals(dataLine.get(i), filler);
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DataProvider(org.testng.annotations.DataProvider) BiFunction(java.util.function.BiFunction) Test(org.testng.annotations.Test) Random(java.util.Random) Collectors(java.util.stream.Collectors) DataLine(org.broadinstitute.hellbender.utils.tsv.DataLine) ArrayList(java.util.ArrayList) List(java.util.List) Stream(java.util.stream.Stream) Assert(org.testng.Assert) TableColumnCollection(org.broadinstitute.hellbender.utils.tsv.TableColumnCollection) DataLine(org.broadinstitute.hellbender.utils.tsv.DataLine) TableColumnCollection(org.broadinstitute.hellbender.utils.tsv.TableColumnCollection) Test(org.testng.annotations.Test)

Example 43 with BiFunction

Use of java.util.function.BiFunction in project gatk-protected by broadinstitute.

In the class ReadCountRecordUnitTest, method testAppendCountsToBeyondEnd.

/**
 * Checks that {@code appendCountsTo} fails with an {@link IllegalStateException} when the
 * {@link DataLine} has already been positioned at {@code columnNames.size()}, i.e. past its
 * last column.
 *
 * @param testName    name of the test case (not used by this test)
 * @param constructor factory that builds the record under test from a target and its counts
 * @param size        number of counts to generate
 */
@Test(dataProvider = "testNonZeroCountsData", dependsOnMethods = "testAppendCountsTo", expectedExceptions = IllegalStateException.class)
public void testAppendCountsToBeyondEnd(@SuppressWarnings("unused") final String testName, final BiFunction<Target, double[], ? extends ReadCountRecord> constructor, final int size) {
    final double[] counts = generateCounts(size);
    final ReadCountRecord record = constructor.apply(TEST_TARGET, counts);
    final List<String> columnNames = Stream.concat(Stream.concat(IntStream.range(0, 10).mapToObj(i -> "pre-padding_" + i), IntStream.range(0, counts.length).mapToObj(i -> "column_" + i)), IntStream.range(0, 10).mapToObj(i -> "post-padding_" + i)).collect(Collectors.toList());
    final TableColumnCollection columns = new TableColumnCollection(columnNames);
    final DataLine dataLine = new DataLine(columns, RuntimeException::new);
    // Move to the position just past the last column; the append below is the call expected to throw.
    dataLine.seek(columnNames.size());
    record.appendCountsTo(dataLine);
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DataProvider(org.testng.annotations.DataProvider) BiFunction(java.util.function.BiFunction) Test(org.testng.annotations.Test) Random(java.util.Random) Collectors(java.util.stream.Collectors) DataLine(org.broadinstitute.hellbender.utils.tsv.DataLine) ArrayList(java.util.ArrayList) List(java.util.List) Stream(java.util.stream.Stream) Assert(org.testng.Assert) TableColumnCollection(org.broadinstitute.hellbender.utils.tsv.TableColumnCollection) DataLine(org.broadinstitute.hellbender.utils.tsv.DataLine) TableColumnCollection(org.broadinstitute.hellbender.utils.tsv.TableColumnCollection) Test(org.testng.annotations.Test)

Aggregations

BiFunction (java.util.function.BiFunction)43 HashMap (java.util.HashMap)22 List (java.util.List)22 Map (java.util.Map)22 ArrayList (java.util.ArrayList)18 Test (org.junit.Test)17 Collectors (java.util.stream.Collectors)15 Collections (java.util.Collections)14 Set (java.util.Set)13 Function (java.util.function.Function)13 Arrays (java.util.Arrays)12 Mockito.mock (org.mockito.Mockito.mock)12 IOException (java.io.IOException)11 Collection (java.util.Collection)11 IntStream (java.util.stream.IntStream)11 Before (org.junit.Before)11 Logger (org.apache.logging.log4j.Logger)8 Config (org.apache.samza.config.Config)8 JobConfig (org.apache.samza.config.JobConfig)8 ApplicationRunner (org.apache.samza.runtime.ApplicationRunner)8