Search in sources :

Example 1 with JavaFutureAction

Use of org.apache.spark.api.java.JavaFutureAction in the geowave project by locationtech.

Class TieredSpatialJoin, method join.

/**
 * Performs a tier-by-tier spatial join of two indexed RDDs using the supplied predicate.
 *
 * <p>Steps, in order:
 * <ol>
 * <li>Pick the {@link TieredSFCIndexStrategy} used for the join: reuse a side's strategy when it
 * supports the join, otherwise create a default one; sides whose strategy is not the chosen one
 * are reindexed.</li>
 * <li>Broadcast the chosen strategy and the predicate, and fetch the indexed geometry RDD of each
 * side (the side opposite the configured build side is requested with the buffer distance).</li>
 * <li>Asynchronously collect the distinct first key byte (used here as the tier id) of each
 * side.</li>
 * <li>Join same-tier data directly, and reproject the other side's tiers selected by the
 * reproject maps (or a whole side, when its lowest tier id is above the other side's highest)
 * before joining.</li>
 * <li>Union and de-duplicate the matches, then publish left/right results via
 * {@code setLeftResults} / {@code setRightResults}.</li>
 * </ol>
 *
 * <p>If no default strategy can be created, an error is logged and the method returns without
 * setting results. If either side has no indexed entries, no pair can match: results are set
 * directly (empty for a positive predicate, the unmodified inputs for a negative predicate)
 * instead of failing on an empty tier array.
 *
 * @param spark session providing the Spark context
 * @param leftRDD left side of the join
 * @param rightRDD right side of the join
 * @param predicate geometry predicate evaluated on candidate pairs; also supplies the buffer amount
 * @throws InterruptedException if waiting for the asynchronous tier collection is interrupted
 * @throws ExecutionException if the asynchronous tier collection fails
 */
@Override
public void join(final SparkSession spark, final GeoWaveIndexedRDD leftRDD, final GeoWaveIndexedRDD rightRDD, final GeomFunction predicate) throws InterruptedException, ExecutionException {
    // Get SparkContext from session
    final SparkContext sc = spark.sparkContext();
    final JavaSparkContext javaSC = JavaSparkContext.fromSparkContext(sc);
    final NumericIndexStrategy leftStrategy = leftRDD.getIndexStrategy().getValue();
    final NumericIndexStrategy rightStrategy = rightRDD.getIndexStrategy().getValue();
    // Check if either dataset supports the join
    TieredSFCIndexStrategy tieredStrategy = null;
    // Determine if either strategy needs to be reindexed to support join algorithm
    boolean reindexLeft = false;
    boolean reindexRight = false;
    final boolean leftSupport = supportsJoin(leftStrategy);
    final boolean rightSupport = supportsJoin(rightStrategy);
    if (leftSupport && rightSupport) {
        if (leftStrategy.equals(rightStrategy)) {
            // Both strategies match we don't have to reindex
            tieredStrategy = (TieredSFCIndexStrategy) leftStrategy;
        } else {
            // Both support the join but differ: keep the build side's strategy, reindex the other
            if (getJoinOptions().getJoinBuildSide() == JoinOptions.BuildSide.LEFT) {
                reindexRight = true;
                tieredStrategy = (TieredSFCIndexStrategy) leftStrategy;
            } else {
                reindexLeft = true;
                tieredStrategy = (TieredSFCIndexStrategy) rightStrategy;
            }
        }
    } else if (leftSupport) {
        reindexRight = true;
        tieredStrategy = (TieredSFCIndexStrategy) leftStrategy;
    } else if (rightSupport) {
        reindexLeft = true;
        tieredStrategy = (TieredSFCIndexStrategy) rightStrategy;
    } else {
        // Neither side supports the join: derive a default strategy and reindex both sides
        tieredStrategy = (TieredSFCIndexStrategy) createDefaultStrategy(leftStrategy);
        if (tieredStrategy == null) {
            tieredStrategy = (TieredSFCIndexStrategy) createDefaultStrategy(rightStrategy);
        }
        if (tieredStrategy == null) {
            LOGGER.error("Cannot create default strategy from either provided strategy. Datasets cannot be joined.");
            return;
        }
        reindexLeft = true;
        reindexRight = true;
    }
    // Pull information and broadcast strategy used for join
    final SubStrategy[] tierStrategies = tieredStrategy.getSubStrategies();
    final int tierCount = tierStrategies.length;
    // Create broadcast variable for indexing strategy
    // Cast is safe because we must be instance of TieredSFCIndexStrategy to support join.
    final Broadcast<TieredSFCIndexStrategy> broadcastStrategy = (Broadcast<TieredSFCIndexStrategy>) RDDUtils.broadcastIndexStrategy(sc, tieredStrategy);
    final Broadcast<GeomFunction> geomPredicate = javaSC.broadcast(predicate);
    // If needed reindex one of the strategies we will wrap the buffer operation into the reindex
    // operation
    // Otherwise we buffer based off the buildside of the join.
    setBufferAmount(predicate.getBufferAmount());
    // Reindex whichever sides need it (left first, then right)
    if (reindexLeft) {
        leftRDD.reindex(broadcastStrategy);
    }
    if (reindexRight) {
        rightRDD.reindex(broadcastStrategy);
    }
    // Get RDD of indexed Geometry; the side opposite the build side gets the buffered variant
    JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> leftIndex = null;
    JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> rightIndex = null;
    if (joinOpts.getJoinBuildSide() == BuildSide.LEFT) {
        rightIndex = rightRDD.getIndexedGeometryRDD(bufferDistance, true);
        leftIndex = leftRDD.getIndexedGeometryRDD();
    } else {
        leftIndex = leftRDD.getIndexedGeometryRDD(bufferDistance, true);
        rightIndex = rightRDD.getIndexedGeometryRDD();
    }
    final int leftPartCount = leftIndex.getNumPartitions();
    final int rightPartCount = rightIndex.getNumPartitions();
    final int highestPartCount = Math.max(leftPartCount, rightPartCount);
    final int largePartitionerCount = (int) (1.5 * highestPartCount);
    final HashPartitioner partitioner = new HashPartitioner(largePartitionerCount);
    // Kick off both tier-id collections before blocking on either result
    final JavaFutureAction<List<Byte>> leftFuture = leftIndex.setName("LeftIndex").keys().map(t -> t.getBytes()[0]).distinct(4).collectAsync();
    final JavaFutureAction<List<Byte>> rightFuture = rightIndex.setName("RightIndex").keys().map(t -> t.getBytes()[0]).distinct(4).collectAsync();
    // Get the result of future
    final List<Byte> rightDataTiers = Lists.newArrayList(rightFuture.get());
    // Sort tier ids ascending (natural order) and collect information.
    final Byte[] rightTierArr = rightDataTiers.toArray(new Byte[0]);
    Arrays.sort(rightTierArr);
    final int rightTierCount = rightTierArr.length;
    final List<Byte> leftDataTiers = Lists.newArrayList(leftFuture.get());
    final Byte[] leftTierArr = leftDataTiers.toArray(new Byte[0]);
    Arrays.sort(leftTierArr);
    final int leftTierCount = leftTierArr.length;
    LOGGER.debug("Tier Count: {}", tierCount);
    LOGGER.debug("Left Tier Count: {} Right Tier Count: {}", leftTierCount, rightTierCount);
    LOGGER.debug("Left Tiers: {}", leftDataTiers);
    LOGGER.debug("Right Tiers: {}", rightDataTiers);
    if ((leftTierCount == 0) || (rightTierCount == 0)) {
        // One side has no indexed entries, so nothing can match. Indexing into the empty tier
        // array below would throw, so publish the trivially known results instead.
        LOGGER.warn("At least one side of the join has no indexed data; no matches are possible.");
        if (getJoinOptions().isNegativePredicate()) {
            // Nothing matched, so nothing is subtracted from either side.
            setLeftResults(new GeoWaveRDD(leftRDD.getGeoWaveRDD().getRawRDD()));
            setRightResults(new GeoWaveRDD(rightRDD.getGeoWaveRDD().getRawRDD()));
        } else {
            // Nothing matched, so both result sets are empty.
            setLeftResults(new GeoWaveRDD(leftRDD.getGeoWaveRDD().getRawRDD().filter(t -> false)));
            setRightResults(new GeoWaveRDD(rightRDD.getGeoWaveRDD().getRawRDD().filter(t -> false)));
        }
        leftIndex.unpersist();
        rightIndex.unpersist();
        return;
    }
    // Determine if there are common higher tiers for whole dataset on either side.
    final byte highestLeftTier = leftTierArr[leftTierCount - 1];
    final byte highestRightTier = rightTierArr[rightTierCount - 1];
    // Whole left dataset is higher tiers than right
    final boolean commonLeftExist = leftTierArr[0] > highestRightTier;
    // Whole right dataset is higher tiers than left
    final boolean commonRightExist = !commonLeftExist && (rightTierArr[0] > highestLeftTier);
    // When one side sits entirely above the other, no per-tier reproject maps are needed
    final boolean skipMapCreate = commonLeftExist || commonRightExist;
    Map<Byte, HashSet<Byte>> rightReprojectMap = new HashMap<>();
    Map<Byte, HashSet<Byte>> leftReprojectMap = new HashMap<>();
    final HashSet<Byte> sharedTiers = Sets.newHashSetWithExpectedSize(tierCount / 2);
    if (!skipMapCreate) {
        leftReprojectMap = createReprojectMap(leftTierArr, rightTierArr, sharedTiers);
        rightReprojectMap = createReprojectMap(rightTierArr, leftTierArr, sharedTiers);
    }
    JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> commonRightRDD = null;
    if (commonRightExist) {
        commonRightRDD = rightRDD.getGeoWaveRDD().getRawRDD().filter(t -> t._2.getDefaultGeometry() != null).mapValues((Function<SimpleFeature, Geometry>) t -> {
            return (Geometry) t.getDefaultGeometry();
        }).distinct(largePartitionerCount).rdd().toJavaRDD();
    }
    JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> commonLeftRDD = null;
    if (commonLeftExist) {
        commonLeftRDD = leftRDD.getGeoWaveRDD().getRawRDD().filter(t -> t._2.getDefaultGeometry() != null).mapValues((Function<SimpleFeature, Geometry>) t -> {
            return (Geometry) t.getDefaultGeometry();
        }).distinct(largePartitionerCount).rdd().toJavaRDD();
    }
    // Iterate through left tiers. Joining higher right and same level tiers
    for (final Byte leftTierId : leftDataTiers) {
        final HashSet<Byte> higherRightTiers = leftReprojectMap.get(leftTierId);
        JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> leftTier = null;
        final boolean higherTiersExist = ((higherRightTiers != null) && !higherRightTiers.isEmpty());
        final boolean sameTierExist = sharedTiers.contains(leftTierId);
        if (commonRightExist || higherTiersExist || sameTierExist) {
            leftTier = filterTier(leftIndex, leftTierId);
        } else {
            // No tiers to compare against this tier
            continue;
        }
        // Check for same tier existence on both sides and join without reprojection.
        if (sameTierExist) {
            final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> rightTier = rightIndex.filter(t -> t._1().getBytes()[0] == leftTierId);
            final JavaPairRDD<GeoWaveInputKey, ByteArray> finalMatches = joinAndCompareTiers(leftTier, rightTier, geomPredicate, highestPartCount, partitioner);
            addMatches(finalMatches);
        }
        // Join against higher common tiers for this dataset
        JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> rightTiers = null;
        if (commonRightExist) {
            rightTiers = commonRightRDD;
        } else if (higherTiersExist) {
            final Broadcast<HashSet<Byte>> higherBroadcast = javaSC.broadcast(higherRightTiers);
            rightTiers = prepareForReproject(rightIndex.filter(t -> higherBroadcast.value().contains(t._1().getBytes()[0])), largePartitionerCount);
        }
        if (rightTiers != null) {
            final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> reprojected = reprojectToTier(rightTiers, leftTierId, broadcastStrategy, getBufferAmount(BuildSide.RIGHT), partitioner);
            final JavaPairRDD<GeoWaveInputKey, ByteArray> finalMatches = joinAndCompareTiers(leftTier, reprojected, geomPredicate, highestPartCount, partitioner);
            addMatches(finalMatches);
        }
    }
    // Iterate through right tiers, joining against higher left tiers (same-tier pairs were
    // already handled in the left loop above)
    for (final Byte rightTierId : rightDataTiers) {
        final HashSet<Byte> higherLeftTiers = rightReprojectMap.get(rightTierId);
        JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> rightTier = null;
        final boolean higherLeftExist = ((higherLeftTiers != null) && !higherLeftTiers.isEmpty());
        if (commonLeftExist || higherLeftExist) {
            rightTier = rightIndex.filter(t -> t._1().getBytes()[0] == rightTierId);
        } else {
            // No tiers to compare against this tier
            continue;
        }
        JavaPairRDD<GeoWaveInputKey, ByteArray> finalMatches = null;
        JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> leftTiers = null;
        if (commonLeftExist) {
            leftTiers = commonLeftRDD;
        } else {
            final Broadcast<HashSet<Byte>> higherBroadcast = javaSC.broadcast(higherLeftTiers);
            leftTiers = prepareForReproject(leftIndex.filter(t -> higherBroadcast.value().contains(t._1.getBytes()[0])), largePartitionerCount);
        }
        final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> reprojected = reprojectToTier(leftTiers, rightTierId, broadcastStrategy, getBufferAmount(BuildSide.LEFT), partitioner);
        finalMatches = joinAndCompareTiers(reprojected, rightTier, geomPredicate, highestPartCount, partitioner);
        addMatches(finalMatches);
    }
    // Remove duplicates between tiers
    combinedResults = javaSC.union((JavaPairRDD[]) (ArrayUtils.add(tierMatches.toArray(new JavaPairRDD[tierMatches.size()]), combinedResults)));
    combinedResults = combinedResults.reduceByKey((id1, id2) -> id1);
    combinedResults = combinedResults.setName("CombinedJoinResults").persist(StorageLevel.MEMORY_ONLY_SER());
    // Force evaluation of RDD at the join function call.
    // Otherwise it doesn't actually perform work until something is called
    // on left/right joined.
    // Wish there was a better way to force evaluation of rdd safely.
    // isEmpty() triggers take(1) which shouldn't involve a shuffle.
    combinedResults.isEmpty();
    // don't recalculate
    if (getJoinOptions().isNegativePredicate()) {
        // Negative predicate: keep the features that did NOT match
        setLeftResults(new GeoWaveRDD(leftRDD.getGeoWaveRDD().getRawRDD().subtractByKey(combinedResults).cache()));
        setRightResults(new GeoWaveRDD(rightRDD.getGeoWaveRDD().getRawRDD().subtractByKey(combinedResults).cache()));
    } else {
        // Positive predicate: keep only the features whose key appears in the matches
        setLeftResults(new GeoWaveRDD(leftRDD.getGeoWaveRDD().getRawRDD().join(combinedResults).mapToPair(t -> new Tuple2<>(t._1(), t._2._1())).cache()));
        setRightResults(new GeoWaveRDD(rightRDD.getGeoWaveRDD().getRawRDD().join(combinedResults).mapToPair(t -> new Tuple2<>(t._1(), t._2._1())).cache()));
    }
    leftIndex.unpersist();
    rightIndex.unpersist();
}
Also used : ByteArray(org.locationtech.geowave.core.index.ByteArray) Arrays(java.util.Arrays) GeoWaveInputKey(org.locationtech.geowave.mapreduce.input.GeoWaveInputKey) GeoWaveRDD(org.locationtech.geowave.analytic.spark.GeoWaveRDD) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) SpatialDimensionalityTypeProvider(org.locationtech.geowave.core.geotime.index.SpatialDimensionalityTypeProvider) SubStrategy(org.locationtech.geowave.core.index.HierarchicalNumericIndexStrategy.SubStrategy) GeomFunction(org.locationtech.geowave.analytic.spark.sparksql.udf.GeomFunction) HashSet(java.util.HashSet) SpatialTemporalOptions(org.locationtech.geowave.core.geotime.index.SpatialTemporalOptions) Lists(com.google.common.collect.Lists) StorageLevel(org.apache.spark.storage.StorageLevel) SimpleFeature(org.opengis.feature.simple.SimpleFeature) Map(java.util.Map) Maps(jersey.repackaged.com.google.common.collect.Maps) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) SparkSession(org.apache.spark.sql.SparkSession) Broadcast(org.apache.spark.broadcast.Broadcast) RDDUtils(org.locationtech.geowave.analytic.spark.RDDUtils) Logger(org.slf4j.Logger) HashPartitioner(org.apache.spark.HashPartitioner) SparkContext(org.apache.spark.SparkContext) TieredSFCIndexStrategy(org.locationtech.geowave.core.index.sfc.tiered.TieredSFCIndexStrategy) TieredSFCIndexFactory(org.locationtech.geowave.core.index.sfc.tiered.TieredSFCIndexFactory) GeometryUtils(org.locationtech.geowave.core.geotime.util.GeometryUtils) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) JavaFutureAction(org.apache.spark.api.java.JavaFutureAction) Sets(com.google.common.collect.Sets) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) 
InsertionIds(org.locationtech.geowave.core.index.InsertionIds) SingleTierSubStrategy(org.locationtech.geowave.core.index.sfc.tiered.SingleTierSubStrategy) SFCType(org.locationtech.geowave.core.index.sfc.SFCFactory.SFCType) Geometry(org.locationtech.jts.geom.Geometry) BuildSide(org.locationtech.geowave.analytic.spark.spatial.JoinOptions.BuildSide) SpatialTemporalDimensionalityTypeProvider(org.locationtech.geowave.core.geotime.index.SpatialTemporalDimensionalityTypeProvider) Function(org.apache.spark.api.java.function.Function) GeoWaveIndexedRDD(org.locationtech.geowave.analytic.spark.GeoWaveIndexedRDD) MultiDimensionalNumericData(org.locationtech.geowave.core.index.numeric.MultiDimensionalNumericData) NumericIndexStrategy(org.locationtech.geowave.core.index.NumericIndexStrategy) Envelope(org.locationtech.jts.geom.Envelope) ArrayUtils(org.apache.commons.lang.ArrayUtils) TieredSFCIndexStrategy(org.locationtech.geowave.core.index.sfc.tiered.TieredSFCIndexStrategy) HashMap(java.util.HashMap) Broadcast(org.apache.spark.broadcast.Broadcast) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ByteArray(org.locationtech.geowave.core.index.ByteArray) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) HashSet(java.util.HashSet) SubStrategy(org.locationtech.geowave.core.index.HierarchicalNumericIndexStrategy.SubStrategy) SingleTierSubStrategy(org.locationtech.geowave.core.index.sfc.tiered.SingleTierSubStrategy) GeoWaveInputKey(org.locationtech.geowave.mapreduce.input.GeoWaveInputKey) SimpleFeature(org.opengis.feature.simple.SimpleFeature) GeomFunction(org.locationtech.geowave.analytic.spark.sparksql.udf.GeomFunction) Geometry(org.locationtech.jts.geom.Geometry) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SparkContext(org.apache.spark.SparkContext) Tuple2(scala.Tuple2) HashPartitioner(org.apache.spark.HashPartitioner) GeoWaveRDD(org.locationtech.geowave.analytic.spark.GeoWaveRDD) 
NumericIndexStrategy(org.locationtech.geowave.core.index.NumericIndexStrategy)

Aggregations

Lists (com.google.common.collect.Lists)1 Sets (com.google.common.collect.Sets)1 Arrays (java.util.Arrays)1 HashMap (java.util.HashMap)1 HashSet (java.util.HashSet)1 List (java.util.List)1 Map (java.util.Map)1 ExecutionException (java.util.concurrent.ExecutionException)1 Maps (jersey.repackaged.com.google.common.collect.Maps)1 ArrayUtils (org.apache.commons.lang.ArrayUtils)1 HashPartitioner (org.apache.spark.HashPartitioner)1 SparkContext (org.apache.spark.SparkContext)1 JavaFutureAction (org.apache.spark.api.java.JavaFutureAction)1 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)1 JavaRDD (org.apache.spark.api.java.JavaRDD)1 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)1 FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction)1 Function (org.apache.spark.api.java.function.Function)1 PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction)1 Broadcast (org.apache.spark.broadcast.Broadcast)1