Example 31 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project bunsen by cerner.

the class BroadcastableMappings method broadcast.

/**
 * Broadcast mappings stored in the given conceptMaps instance that match the given
 * conceptMapUris.
 *
 * @param conceptMaps the {@link ConceptMaps} instance with the content to broadcast
 * @param conceptMapUriToVersion a map from each concept map URI to broadcast to the version of that map to use.
 * @return a broadcast variable containing a mappings object usable in UDFs.
 */
public static Broadcast<BroadcastableMappings> broadcast(ConceptMaps conceptMaps,
        Map<String, String> conceptMapUriToVersion) {
    Map<String, ConceptMap> mapsToLoad = conceptMaps.getMaps()
        .collectAsList()
        .stream()
        .filter(conceptMap ->
            conceptMap.getVersion().equals(conceptMapUriToVersion.get(conceptMap.getUrl())))
        .collect(Collectors.toMap(ConceptMap::getUrl, Function.identity()));
    // Expand the concept maps to load and sort them so dependencies are before
    // their dependents in the list.
    List<String> sortedMapsToLoad = sortMapsToLoad(conceptMapUriToVersion.keySet(), mapsToLoad);
    // Since this is used to map from one system to another, we use only targets
    // that don't introduce inaccurate meanings. (For instance, we can't map
    // general condition code to a more specific type, since that is not
    // representative of the source data.)
    Dataset<Mapping> mappings = conceptMaps.getMappings(conceptMapUriToVersion)
        .filter("equivalence in ('equivalent', 'equals', 'wider', 'subsumes')");
    // Group mappings by their concept map URI
    Map<String, List<Mapping>> groupedMappings = mappings.collectAsList()
        .stream()
        .collect(Collectors.groupingBy(Mapping::getConceptMapUri));
    Map<String, BroadcastableConceptMap> broadcastableMaps = new HashMap<>();
    for (String conceptMapUri : sortedMapsToLoad) {
        ConceptMap map = mapsToLoad.get(conceptMapUri);
        Set<String> children = getMapChildren(map);
        List<BroadcastableConceptMap> childMaps = children.stream()
            .map(broadcastableMaps::get)
            .collect(Collectors.toList());
        BroadcastableConceptMap broadcastableConceptMap = new BroadcastableConceptMap(conceptMapUri,
            groupedMappings.getOrDefault(conceptMapUri, Collections.emptyList()),
            childMaps);
        broadcastableMaps.put(conceptMapUri, broadcastableConceptMap);
    }
    JavaSparkContext ctx = new JavaSparkContext(conceptMaps.getMaps().sparkSession().sparkContext());
    return ctx.broadcast(new BroadcastableMappings(broadcastableMaps));
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) Dataset(org.apache.spark.sql.Dataset) ConceptMaps(com.cerner.bunsen.codes.ConceptMaps) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Set(java.util.Set) HashMap(java.util.HashMap) Deque(java.util.Deque) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) List(java.util.List) Map(java.util.Map) ConceptMapGroupUnmappedMode(org.hl7.fhir.dstu3.model.ConceptMap.ConceptMapGroupUnmappedMode) ArrayDeque(java.util.ArrayDeque) Collections(java.util.Collections) ConceptMap(org.hl7.fhir.dstu3.model.ConceptMap) Mapping(com.cerner.bunsen.codes.Mapping) SparkSession(org.apache.spark.sql.SparkSession)
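
A minimal usage sketch for this example: register a Spark SQL UDF that reads the broadcast value on the executors. The spark, conceptMaps, and conceptMapUriToVersion references are assumed to be in scope, and the getConceptMap/getTarget calls are hypothetical stand-ins for whatever lookup methods BroadcastableMappings actually exposes; only Broadcast.value() and the UDF registration API are standard Spark.

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;

Broadcast<BroadcastableMappings> broadcastMappings =
    BroadcastableMappings.broadcast(conceptMaps, conceptMapUriToVersion);

// Executors read the shared copy via value(); nothing is re-serialized per task.
UDF2<String, String, String> translateCode = (system, code) ->
    broadcastMappings.value()
        .getConceptMap("urn:example:conceptmap") // hypothetical accessor
        .getTarget(system, code); // hypothetical accessor

spark.udf().register("translate_code", translateCode, DataTypes.StringType);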

Example 32 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project pyramid by cheng-li.

the class SparkCBMOptimizer method updateBinaryClassifiers.

private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    Classifier.ProbabilityEstimator[][] localBinaryClassifiers = cbm.binaryClassifiers;
    double[][] localGammasT = gammasT;
    Broadcast<MultiLabelClfDataSet> localDataSetBroadcast = dataSetBroadCast;
    Broadcast<double[][][]> localTargetsBroadcast = targetDisBroadCast;
    double localVariance = priorVarianceBinary;
    List<BinaryTask> binaryTaskList = new ArrayList<>();
    for (int k = 0; k < cbm.numComponents; k++) {
        for (int l = 0; l < cbm.numLabels; l++) {
            LogisticRegression logisticRegression = (LogisticRegression) localBinaryClassifiers[k][l];
            double[] weights = localGammasT[k];
            binaryTaskList.add(new BinaryTask(k, l, logisticRegression, weights));
        }
    }
    JavaRDD<BinaryTask> binaryTaskRDD = sparkContext.parallelize(binaryTaskList, binaryTaskList.size());
    List<BinaryTaskResult> results = binaryTaskRDD.map(binaryTask -> {
        int labelIndex = binaryTask.classIndex;
        // each element in the RDD carries its full task information
        return updateBinaryLogisticRegression(
            binaryTask.componentIndex,
            binaryTask.classIndex,
            binaryTask.logisticRegression,
            localDataSetBroadcast.value(),
            binaryTask.weights,
            localTargetsBroadcast.value()[labelIndex],
            localVariance);
    }).collect();
    for (BinaryTaskResult result : results) {
        cbm.binaryClassifiers[result.componentIndex][result.classIndex] = result.binaryClassifier;
    }
    // IntStream.range(0, cbm.numComponents).forEach(this::updateBinaryClassifiers);
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
Also used : IntStream(java.util.stream.IntStream) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ArrayList(java.util.ArrayList) Classifier(edu.neu.ccs.pyramid.classification.Classifier) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) RidgeLogisticOptimizer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost) LogisticLoss(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss) JavaRDD(org.apache.spark.api.java.JavaRDD) Broadcast(org.apache.spark.broadcast.Broadcast) LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Serializable(java.io.Serializable) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence) List(java.util.List) Logger(org.apache.logging.log4j.Logger) ElasticNetLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer) Entropy(edu.neu.ccs.pyramid.eval.Entropy) Vector(org.apache.mahout.math.Vector) LogManager(org.apache.logging.log4j.LogManager) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)
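
The pattern here, broadcasting the large read-only inputs once and then parallelizing a list of small task descriptors that read them via value(), can be reproduced in isolation. Below is a self-contained sketch with illustrative names: the weights array stands in for the broadcast dataset, and the integer IDs stand in for the BinaryTask objects.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BroadcastTaskSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local[*]", "broadcast-task-sketch");
        // Shipped to each executor once, instead of being serialized into every task closure.
        Broadcast<double[]> sharedWeights = sc.broadcast(new double[] { 0.1, 0.2, 0.7 });
        List<Integer> taskIds = Arrays.asList(0, 1, 2);
        // One partition per task, mirroring parallelize(binaryTaskList, binaryTaskList.size()) above.
        List<Double> results = sc.parallelize(taskIds, taskIds.size())
            .map(i -> sharedWeights.value()[i] * 10.0)
            .collect();
        System.out.println(results); // [1.0, 2.0, 7.0]
        sc.close();
    }
}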

Example 33 with Broadcast

use of org.apache.spark.broadcast.Broadcast in project beam by apache.

the class SparkBatchPortablePipelineTranslator method broadcastSideInputs.

/**
 * Broadcast the side inputs of an executable stage. *This can be expensive.*
 *
 * @return Map from PCollection ID to Spark broadcast variable and coder to decode its contents.
 */
private static <SideInputT>
        ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> broadcastSideInputs(
            RunnerApi.ExecutableStagePayload stagePayload, SparkTranslationContext context) {
    Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> broadcastVariables =
        new HashMap<>();
    for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
        RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
        String collectionId = stagePayloadComponents
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
        if (broadcastVariables.containsKey(collectionId)) {
            // This PCollection has already been broadcast.
            continue;
        }
        Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
            broadcastSideInput(collectionId, stagePayloadComponents, context);
        broadcastVariables.put(collectionId, tuple2);
    }
    return ImmutableMap.copyOf(broadcastVariables);
}
Also used : RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) PipelineTranslatorUtils.getWindowedValueCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowedValueCoder) HashMap(java.util.HashMap) Broadcast(org.apache.spark.broadcast.Broadcast) Tuple2(scala.Tuple2) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) List(java.util.List) SideInputId(org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId)
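
For completeness, a sketch of how a consumer might materialize one of these side inputs; it assumes the caller holds a tuple from the map returned above. CoderUtils.decodeFromByteArray is the standard Beam helper for decoding a byte array with its coder, and error handling is reduced to a throws clause.

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

static <SideInputT> List<WindowedValue<SideInputT>> materializeSideInput(
        Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> sideInput) throws CoderException {
    List<WindowedValue<SideInputT>> values = new ArrayList<>();
    for (byte[] encoded : sideInput._1().value()) {
        // Decode each element with the coder that was broadcast alongside the bytes.
        values.add(CoderUtils.decodeFromByteArray(sideInput._2(), encoded));
    }
    return values;
}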

Aggregations

Broadcast (org.apache.spark.broadcast.Broadcast): 33 uses
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 26 uses
Collectors (java.util.stream.Collectors): 25 uses
List (java.util.List): 20 uses
JavaRDD (org.apache.spark.api.java.JavaRDD): 17 uses
Tuple2 (scala.Tuple2): 16 uses
IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils): 15 uses
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 13 uses
Argument (org.broadinstitute.barclay.argparser.Argument): 13 uses
SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary): 11 uses
IntStream (java.util.stream.IntStream): 11 uses
LogManager (org.apache.logging.log4j.LogManager): 11 uses
Logger (org.apache.logging.log4j.Logger): 11 uses
ReferenceMultiSource (org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource): 11 uses
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval): 11 uses
StreamSupport (java.util.stream.StreamSupport): 10 uses
FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction): 10 uses
org.broadinstitute.hellbender.engine (org.broadinstitute.hellbender.engine): 10 uses
UserException (org.broadinstitute.hellbender.exceptions.UserException): 10 uses
GATKException (org.broadinstitute.hellbender.exceptions.GATKException): 9 uses