Use of org.locationtech.geowave.analytic.spark.GeoWaveIndexedRDD in the geowave project by locationtech.
The join method of the class TieredSpatialJoin.
/**
 * Joins two indexed datasets tier by tier and publishes the matching features of each side via
 * {@code setLeftResults} and {@code setRightResults}.
 *
 * <p>The method proceeds in four phases:
 * <ol>
 * <li>Choose the tiered index strategy to join on and invoke {@code reindex} on whichever side
 * does not already use it.
 * <li>Collect the distinct tier ids (the first byte of every index key) present on each side.
 * <li>For every tier, join same-tier data directly and join data from the other side's higher
 * tiers after reprojecting it onto the current tier.
 * <li>Union and de-duplicate the per-tier matches, then derive the left and right result sets
 * (or their complements when the join options request a negative predicate).
 * </ol>
 *
 * <p>If neither input strategy supports the join and no default strategy can be created from
 * either of them, an error is logged and the method returns without setting any results.
 *
 * @param spark session whose SparkContext is used to run the join
 * @param leftRDD left dataset; {@code reindex} is invoked on it when its strategy is not used
 * @param rightRDD right dataset; {@code reindex} is invoked on it when its strategy is not used
 * @param predicate geometry predicate broadcast to the executors; also supplies the buffer amount
 * @throws InterruptedException if interrupted while waiting for the asynchronous tier collection
 * @throws ExecutionException if one of the asynchronous tier collection jobs fails
 */
@Override
public void join(final SparkSession spark, final GeoWaveIndexedRDD leftRDD, final GeoWaveIndexedRDD rightRDD, final GeomFunction predicate) throws InterruptedException, ExecutionException {
  // Get SparkContext from session
  final SparkContext sc = spark.sparkContext();
  final JavaSparkContext javaSC = JavaSparkContext.fromSparkContext(sc);
  final NumericIndexStrategy leftStrategy = leftRDD.getIndexStrategy().getValue();
  final NumericIndexStrategy rightStrategy = rightRDD.getIndexStrategy().getValue();

  // ---- Phase 1: decide which strategy drives the join and which side(s) must be reindexed ----
  final TieredSFCIndexStrategy tieredStrategy;
  boolean reindexLeft = false;
  boolean reindexRight = false;
  final boolean leftSupport = supportsJoin(leftStrategy);
  final boolean rightSupport = supportsJoin(rightStrategy);
  if (leftSupport && rightSupport) {
    if (leftStrategy.equals(rightStrategy)) {
      // Both strategies match, nothing has to be reindexed
      tieredStrategy = (TieredSFCIndexStrategy) leftStrategy;
    } else if (getJoinOptions().getJoinBuildSide() == BuildSide.LEFT) {
      // Both support the join but differ: keep the build side's strategy, reindex the other side
      reindexRight = true;
      tieredStrategy = (TieredSFCIndexStrategy) leftStrategy;
    } else {
      reindexLeft = true;
      tieredStrategy = (TieredSFCIndexStrategy) rightStrategy;
    }
  } else if (leftSupport) {
    reindexRight = true;
    tieredStrategy = (TieredSFCIndexStrategy) leftStrategy;
  } else if (rightSupport) {
    reindexLeft = true;
    tieredStrategy = (TieredSFCIndexStrategy) rightStrategy;
  } else {
    // Neither side supports the join: try to derive a default strategy, left first then right
    TieredSFCIndexStrategy defaultStrategy = (TieredSFCIndexStrategy) createDefaultStrategy(leftStrategy);
    if (defaultStrategy == null) {
      defaultStrategy = (TieredSFCIndexStrategy) createDefaultStrategy(rightStrategy);
    }
    if (defaultStrategy == null) {
      LOGGER.error("Cannot create default strategy from either provided strategy. Datasets cannot be joined.");
      return;
    }
    tieredStrategy = defaultStrategy;
    reindexLeft = true;
    reindexRight = true;
  }

  // Pull information and broadcast strategy used for join
  final SubStrategy[] tierStrategies = tieredStrategy.getSubStrategies();
  final int tierCount = tierStrategies.length;
  // Cast is safe because we must be instance of TieredSFCIndexStrategy to support join.
  final Broadcast<TieredSFCIndexStrategy> broadcastStrategy = (Broadcast<TieredSFCIndexStrategy>) RDDUtils.broadcastIndexStrategy(sc, tieredStrategy);
  final Broadcast<GeomFunction> geomPredicate = javaSC.broadcast(predicate);

  // The buffer amount comes from the predicate and is recorded before any reindexing happens.
  setBufferAmount(predicate.getBufferAmount());

  // Reindex whichever side(s) were flagged above (left first, then right).
  if (reindexLeft) {
    leftRDD.reindex(broadcastStrategy);
  }
  if (reindexRight) {
    rightRDD.reindex(broadcastStrategy);
  }

  // The side opposite the build side is fetched with bufferDistance; the build side is fetched
  // without arguments.
  final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> leftIndex;
  final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> rightIndex;
  if (getJoinOptions().getJoinBuildSide() == BuildSide.LEFT) {
    rightIndex = rightRDD.getIndexedGeometryRDD(bufferDistance, true);
    leftIndex = leftRDD.getIndexedGeometryRDD();
  } else {
    leftIndex = leftRDD.getIndexedGeometryRDD(bufferDistance, true);
    rightIndex = rightRDD.getIndexedGeometryRDD();
  }

  final int leftPartCount = leftIndex.getNumPartitions();
  final int rightPartCount = rightIndex.getNumPartitions();
  final int highestPartCount = Math.max(leftPartCount, rightPartCount);
  final int largePartitionerCount = (int) (1.5 * highestPartCount);
  final HashPartitioner partitioner = new HashPartitioner(largePartitionerCount);

  // ---- Phase 2: collect the distinct tier ids of both sides (both jobs are started first) ----
  final JavaFutureAction<List<Byte>> leftFuture = leftIndex.setName("LeftIndex").keys().map(t -> t.getBytes()[0]).distinct(4).collectAsync();
  final JavaFutureAction<List<Byte>> rightFuture = rightIndex.setName("RightIndex").keys().map(t -> t.getBytes()[0]).distinct(4).collectAsync();

  // Block on the futures and sort each side's tier ids ascending.
  final List<Byte> rightDataTiers = Lists.newArrayList(rightFuture.get());
  final Byte[] rightTierArr = rightDataTiers.toArray(new Byte[0]);
  Arrays.sort(rightTierArr);
  final int rightTierCount = rightTierArr.length;

  final List<Byte> leftDataTiers = Lists.newArrayList(leftFuture.get());
  final Byte[] leftTierArr = leftDataTiers.toArray(new Byte[0]);
  Arrays.sort(leftTierArr);
  final int leftTierCount = leftTierArr.length;

  // NOTE(review): if either side produced no tier ids, the array accesses below throw
  // ArrayIndexOutOfBoundsException - confirm callers never pass an empty dataset.
  final byte highestLeftTier = leftTierArr[leftTierArr.length - 1];
  final byte highestRightTier = rightTierArr[rightTierArr.length - 1];

  // When every tier of one side is above every tier of the other side, that whole side is used
  // as the "common" higher-tier data and no reprojection map is needed.
  final boolean commonLeftExist = leftTierArr[0] > highestRightTier;
  final boolean commonRightExist = !commonLeftExist && (rightTierArr[0] > highestLeftTier);
  final boolean skipMapCreate = commonLeftExist || commonRightExist;

  LOGGER.debug("Tier Count: " + tierCount);
  LOGGER.debug("Left Tier Count: " + leftTierCount + " Right Tier Count: " + rightTierCount);
  LOGGER.debug("Left Tiers: " + leftDataTiers);
  LOGGER.debug("Right Tiers: " + rightDataTiers);

  Map<Byte, HashSet<Byte>> rightReprojectMap = new HashMap<>();
  Map<Byte, HashSet<Byte>> leftReprojectMap = new HashMap<>();
  final HashSet<Byte> sharedTiers = Sets.newHashSetWithExpectedSize(tierCount / 2);
  if (!skipMapCreate) {
    leftReprojectMap = createReprojectMap(leftTierArr, rightTierArr, sharedTiers);
    rightReprojectMap = createReprojectMap(rightTierArr, leftTierArr, sharedTiers);
  }

  // At most one of these is non-null (the two "common" flags are mutually exclusive).
  final JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> commonRightRDD = commonRightExist ? collectRawGeometries(rightRDD, largePartitionerCount) : null;
  final JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> commonLeftRDD = commonLeftExist ? collectRawGeometries(leftRDD, largePartitionerCount) : null;

  // ---- Phase 3a: iterate through left tiers, joining same-level and higher right tiers ----
  for (final Byte leftTierId : leftDataTiers) {
    final HashSet<Byte> higherRightTiers = leftReprojectMap.get(leftTierId);
    final boolean higherTiersExist = (higherRightTiers != null) && !higherRightTiers.isEmpty();
    final boolean sameTierExist = sharedTiers.contains(leftTierId);
    if (!(commonRightExist || higherTiersExist || sameTierExist)) {
      // No tiers to compare against this tier
      continue;
    }
    final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> leftTier = filterTier(leftIndex, leftTierId);

    // Same tier present on both sides: join without reprojection.
    if (sameTierExist) {
      final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> rightTier = filterTier(rightIndex, leftTierId);
      addMatches(joinAndCompareTiers(leftTier, rightTier, geomPredicate, highestPartCount, partitioner));
    }

    // Higher right-side data: either the whole right dataset or only the mapped higher tiers.
    JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> rightTiers = null;
    if (commonRightExist) {
      rightTiers = commonRightRDD;
    } else if (higherTiersExist) {
      final Broadcast<HashSet<Byte>> higherBroadcast = javaSC.broadcast(higherRightTiers);
      rightTiers = prepareForReproject(rightIndex.filter(t -> higherBroadcast.value().contains(t._1().getBytes()[0])), largePartitionerCount);
    }
    if (rightTiers != null) {
      final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> reprojected = reprojectToTier(rightTiers, leftTierId, broadcastStrategy, getBufferAmount(BuildSide.RIGHT), partitioner);
      addMatches(joinAndCompareTiers(leftTier, reprojected, geomPredicate, highestPartCount, partitioner));
    }
  }

  // ---- Phase 3b: iterate through right tiers, joining higher left tiers only ----
  // (same-tier pairs were already handled in the left loop above)
  for (final Byte rightTierId : rightDataTiers) {
    final HashSet<Byte> higherLeftTiers = rightReprojectMap.get(rightTierId);
    final boolean higherLeftExist = (higherLeftTiers != null) && !higherLeftTiers.isEmpty();
    if (!(commonLeftExist || higherLeftExist)) {
      // No tiers to compare against this tier
      continue;
    }
    final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> rightTier = filterTier(rightIndex, rightTierId);

    final JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> leftTiers;
    if (commonLeftExist) {
      leftTiers = commonLeftRDD;
    } else {
      final Broadcast<HashSet<Byte>> higherBroadcast = javaSC.broadcast(higherLeftTiers);
      leftTiers = prepareForReproject(leftIndex.filter(t -> higherBroadcast.value().contains(t._1.getBytes()[0])), largePartitionerCount);
    }
    final JavaPairRDD<ByteArray, Tuple2<GeoWaveInputKey, Geometry>> reprojected = reprojectToTier(leftTiers, rightTierId, broadcastStrategy, getBufferAmount(BuildSide.LEFT), partitioner);
    addMatches(joinAndCompareTiers(reprojected, rightTier, geomPredicate, highestPartCount, partitioner));
  }

  // ---- Phase 4: combine per-tier matches and publish results ----
  // Remove duplicates between tiers
  combinedResults = javaSC.union((JavaPairRDD[]) (ArrayUtils.add(tierMatches.toArray(new JavaPairRDD[tierMatches.size()]), combinedResults)));
  combinedResults = combinedResults.reduceByKey((id1, id2) -> id1);
  combinedResults = combinedResults.setName("CombinedJoinResults").persist(StorageLevel.MEMORY_ONLY_SER());
  // Force evaluation of RDD at the join function call.
  // Otherwise it doesn't actually perform work until something is called
  // on left/right joined.
  // Wish there was a better way to force evaluation of rdd safely.
  // isEmpty() triggers take(1) which shouldn't involve a shuffle.
  combinedResults.isEmpty();

  // Results are cached so they are not recalculated.
  if (getJoinOptions().isNegativePredicate()) {
    // Negative predicate: keep the features whose keys did NOT match.
    setLeftResults(new GeoWaveRDD(leftRDD.getGeoWaveRDD().getRawRDD().subtractByKey(combinedResults).cache()));
    setRightResults(new GeoWaveRDD(rightRDD.getGeoWaveRDD().getRawRDD().subtractByKey(combinedResults).cache()));
  } else {
    // Keep the features whose keys matched, dropping the joined value again.
    setLeftResults(new GeoWaveRDD(leftRDD.getGeoWaveRDD().getRawRDD().join(combinedResults).mapToPair(t -> new Tuple2<>(t._1(), t._2._1())).cache()));
    setRightResults(new GeoWaveRDD(rightRDD.getGeoWaveRDD().getRawRDD().join(combinedResults).mapToPair(t -> new Tuple2<>(t._1(), t._2._1())).cache()));
  }
  leftIndex.unpersist();
  rightIndex.unpersist();
}

/**
 * Extracts the default geometry of every feature in the raw RDD of {@code indexedRDD}, skipping
 * features without a default geometry, and de-duplicates the (key, geometry) pairs.
 *
 * @param indexedRDD dataset whose raw features are read
 * @param numPartitions partition count passed to {@code distinct}
 * @return distinct (key, geometry) pairs of the dataset
 */
private static JavaRDD<Tuple2<GeoWaveInputKey, Geometry>> collectRawGeometries(final GeoWaveIndexedRDD indexedRDD, final int numPartitions) {
  return indexedRDD.getGeoWaveRDD().getRawRDD()
      .filter(t -> t._2.getDefaultGeometry() != null)
      .mapValues((Function<SimpleFeature, Geometry>) t -> {
        return (Geometry) t.getDefaultGeometry();
      })
      .distinct(numPartitions)
      .rdd()
      .toJavaRDD();
}
Aggregations