Use of io.cdap.cdap.etl.api.join.JoinCondition in project cdap by caskdata.
Class SparkPipelineRunner, method handleAutoJoinOnKeys.
/**
 * Collects the pieces of information needed for the join into a JoinRequest: the SparkCollection,
 * schema, join key, and join type for each stage involved in the join.
 */
private SparkCollection<Object> handleAutoJoinOnKeys(String stageName, JoinDefinition joinDefinition,
                                                     Map<String, SparkCollection<Object>> inputDataCollections,
                                                     @Nullable Integer numPartitions) {
  // Sort the stages to join so that broadcasts happen last. This ensures the left side is not a
  // broadcast, so that we don't try to broadcast both sides of the join. It also causes less data
  // to be shuffled for the non-broadcast joins.
  List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
  joinOrder.sort((s1, s2) -> {
    if (s1.isBroadcast() && !s2.isBroadcast()) {
      return 1;
    } else if (!s1.isBroadcast() && s2.isBroadcast()) {
      return -1;
    }
    return 0;
  });

  Iterator<JoinStage> stageIter = joinOrder.iterator();
  JoinStage left = stageIter.next();
  String leftName = left.getStageName();
  SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
  Schema leftSchema = left.getSchema();

  JoinCondition condition = joinDefinition.getCondition();
  JoinCondition.OnKeys onKeys = (JoinCondition.OnKeys) condition;
  // If this is a join on A.x = B.y = C.z and A.k = B.k = C.k, then stageKeys will look like:
  //   A -> [x, k]
  //   B -> [y, k]
  //   C -> [z, k]
  Map<String, List<String>> stageKeys = onKeys.getKeys().stream()
    .collect(Collectors.toMap(JoinKey::getStageName, JoinKey::getFields));

  Schema outputSchema = joinDefinition.getOutputSchema();
  // once schema is properly propagated at runtime, this condition should no longer happen
  if (outputSchema == null) {
    throw new IllegalArgumentException(String.format(
      "Joiner stage '%s' cannot calculate its output schema because "
        + "one or more inputs have dynamic or unknown schema. "
        + "An output schema must be directly provided.", stageName));
  }

  List<JoinCollection> toJoin = new ArrayList<>();
  List<Schema> keySchema = null;
  while (stageIter.hasNext()) {
    // in this loop, the information for each stage to be joined is gathered into a JoinCollection
    JoinStage right = stageIter.next();
    String rightName = right.getStageName();
    Schema rightSchema = right.getSchema();
    List<String> key = stageKeys.get(rightName);
    if (rightSchema == null) {
      if (keySchema == null) {
        keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
      }
      // if the schema is not known, generate it from the provided output schema and the selected fields
      rightSchema = deriveInputSchema(stageName, rightName, key, keySchema,
                                      joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
    } else {
      // Drop fields that aren't included in the final output. This isn't needed if the schema was
      // derived, since a derived schema already contains only fields in the output schema.
      Set<String> requiredFields = new HashSet<>(key);
      for (JoinField joinField : joinDefinition.getSelectedFields()) {
        if (!joinField.getStageName().equals(rightName)) {
          continue;
        }
        requiredFields.add(joinField.getFieldName());
      }
      rightSchema = trimSchema(rightSchema, requiredFields);
    }
    // a JoinCollection contains the stage name, SparkCollection, schema, join key,
    // whether the stage is required, and whether to broadcast it
    toJoin.add(new JoinCollection(rightName, inputDataCollections.get(rightName), rightSchema, key,
                                  right.isRequired(), right.isBroadcast()));
  }

  List<String> leftKey = stageKeys.get(leftName);
  if (leftSchema == null) {
    if (keySchema == null) {
      keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
    }
    leftSchema = deriveInputSchema(stageName, leftName, leftKey, keySchema,
                                   joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
  }

  // the JoinRequest contains the left side of the join plus one or more other stages to join to
  JoinRequest joinRequest = new JoinRequest(stageName, leftName, leftKey, leftSchema, left.isRequired(),
                                            onKeys.isNullSafe(), joinDefinition.getSelectedFields(),
                                            joinDefinition.getOutputSchema(), toJoin, numPartitions,
                                            joinDefinition.getDistribution(), joinDefinition);
  return leftCollection.join(joinRequest);
}
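
The stageKeys map in the comment above comes from an on-keys condition assembled elsewhere. As a point of reference, the sketch below builds the A/B/C condition from that comment and derives the same map; the stage and field names are illustrative, and the builder calls mirror the ones used by MockAutoJoiner later on this page.

// Sketch only: builds the condition behind the stageKeys example above,
// i.e. a join on A.x = B.y = C.z and A.k = B.k = C.k.
JoinCondition.OnKeys.Builder conditionBuilder = JoinCondition.onKeys().setNullSafe(false);
conditionBuilder.addKey(new JoinKey("A", Arrays.asList("x", "k")));
conditionBuilder.addKey(new JoinKey("B", Arrays.asList("y", "k")));
conditionBuilder.addKey(new JoinKey("C", Arrays.asList("z", "k")));
JoinCondition condition = conditionBuilder.build();
// casting to OnKeys and collecting, as in handleAutoJoinOnKeys, yields:
//   A -> [x, k], B -> [y, k], C -> [z, k]
Map<String, List<String>> stageKeys = ((JoinCondition.OnKeys) condition).getKeys().stream()
  .collect(Collectors.toMap(JoinKey::getStageName, JoinKey::getFields));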
Use of io.cdap.cdap.etl.api.join.JoinCondition in project cdap by caskdata.
Class JoinOnFunction, method createInitializedJoinOnTransform.
private JoinOnTransform<INPUT_RECORD, JOIN_KEY> createInitializedJoinOnTransform() throws Exception {
  Object plugin = pluginFunctionContext.createPlugin();
  BatchJoiner<JOIN_KEY, INPUT_RECORD, Object> joiner;
  boolean filterNullKeys = false;
  if (plugin instanceof BatchAutoJoiner) {
    BatchAutoJoiner autoJoiner = (BatchAutoJoiner) plugin;
    AutoJoinerContext autoJoinerContext = pluginFunctionContext.createAutoJoinerContext();
    JoinDefinition joinDefinition = autoJoiner.define(autoJoinerContext);
    autoJoinerContext.getFailureCollector().getOrThrowException();
    String stageName = pluginFunctionContext.getStageName();
    if (joinDefinition == null) {
      throw new IllegalStateException(String.format(
        "Join stage '%s' did not specify a join definition. "
          + "Check with the plugin developer to ensure it is implemented correctly.", stageName));
    }
    JoinCondition condition = joinDefinition.getCondition();
    /*
       Filter out a record if it comes from an optional stage and its key is null,
       or if any of the fields in the key is null.
       For example, suppose we are performing a left outer join on:

         A (id, name)  = (0, alice), (null, bob)
         B (id, email) = (0, alice@example.com), (null, placeholder@example.com)

       The final output should be:

         joined (A.id, A.name, B.email) = (0, alice, alice@example.com), (null, bob, null)

       That is, the bob record should not be joined to the placeholder@example.com record,
       even though both of their ids are null.
     */
    if (condition.getOp() == JoinCondition.Op.KEY_EQUALITY && !((JoinCondition.OnKeys) condition).isNullSafe()) {
      filterNullKeys = joinDefinition.getStages().stream()
        .filter(s -> !s.isRequired())
        .map(JoinStage::getStageName)
        .anyMatch(s -> s.equals(inputStageName));
    }
    joiner = new JoinerBridge(stageName, autoJoiner, joinDefinition);
  } else {
    joiner = (BatchJoiner<JOIN_KEY, INPUT_RECORD, Object>) plugin;
    BatchJoinerRuntimeContext context = pluginFunctionContext.createBatchRuntimeContext();
    joiner.initialize(context);
  }
  return new JoinOnTransform<>(joiner, inputStageName, filterNullKeys);
}
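
JoinOnTransform itself is not shown on this page. The sketch below illustrates the per-record behavior that the filterNullKeys flag implies; keyOf(...) is a hypothetical helper standing in for however the transform extracts a record's join key, not part of the CDAP API.

// Hedged sketch of the per-record filtering implied by filterNullKeys.
private boolean shouldEmit(Object record, boolean filterNullKeys) {
  if (!filterNullKeys) {
    return true;
  }
  List<Object> keyValues = keyOf(record); // hypothetical key extraction
  // drop the record if its key, or any field in the key, is null
  return keyValues != null && keyValues.stream().noneMatch(Objects::isNull);
}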
Use of io.cdap.cdap.etl.api.join.JoinCondition in project cdap by caskdata.
Class MapReduceTransformExecutorFactory, method getTransformation.
@SuppressWarnings("unchecked")
@Override
protected <IN, OUT> TrackedTransform<IN, OUT> getTransformation(StageSpec stageSpec) throws Exception {
  String stageName = stageSpec.getName();
  String pluginType = stageSpec.getPluginType();
  StageMetrics stageMetrics = new DefaultStageMetrics(metrics, stageName);
  TaskAttemptContext taskAttemptContext = (TaskAttemptContext) taskContext.getHadoopContext();
  StageStatisticsCollector collector = collectStageStatistics
    ? new MapReduceStageStatisticsCollector(stageName, taskAttemptContext)
    : new NoopStageStatisticsCollector();
  if (BatchAggregator.PLUGIN_TYPE.equals(pluginType)) {
    Object plugin = pluginInstantiator.newPluginInstance(stageName, macroEvaluator);
    BatchAggregator<?, ?, ?> batchAggregator;
    if (plugin instanceof BatchReducibleAggregator) {
      BatchReducibleAggregator<?, ?, ?, ?> reducibleAggregator = (BatchReducibleAggregator<?, ?, ?, ?>) plugin;
      batchAggregator = new AggregatorBridge<>(reducibleAggregator);
    } else {
      batchAggregator = (BatchAggregator<?, ?, ?>) plugin;
    }
    BatchRuntimeContext runtimeContext = createRuntimeContext(stageSpec);
    batchAggregator.initialize(runtimeContext);
    if (isMapPhase) {
      return getTrackedEmitKeyStep(
        new MapperAggregatorTransformation(batchAggregator, mapOutputKeyClassName, mapOutputValClassName),
        stageMetrics, getDataTracer(stageName), collector);
    } else {
      return getTrackedAggregateStep(
        new ReducerAggregatorTransformation(batchAggregator, mapOutputKeyClassName, mapOutputValClassName),
        stageMetrics, getDataTracer(stageName), collector);
    }
  } else if (BatchJoiner.PLUGIN_TYPE.equals(pluginType)) {
    Object plugin = pluginInstantiator.newPluginInstance(stageName, macroEvaluator);
    BatchJoiner<?, ?, ?> batchJoiner;
    Set<String> filterNullKeyStages = new HashSet<>();
    if (plugin instanceof BatchAutoJoiner) {
      BatchAutoJoiner autoJoiner = (BatchAutoJoiner) plugin;
      FailureCollector failureCollector = new LoggingFailureCollector(stageName, stageSpec.getInputSchemas());
      DefaultAutoJoinerContext context = DefaultAutoJoinerContext.from(stageSpec.getInputSchemas(), failureCollector);
      // the definition will be non-null because it was validated by PipelinePhasePreparer at the start of the run
      JoinDefinition joinDefinition = autoJoiner.define(context);
      JoinCondition condition = joinDefinition.getCondition();
      // this should never happen since it is checked at deployment time, but check again to be safe
      if (condition.getOp() != JoinCondition.Op.KEY_EQUALITY) {
        failureCollector.addFailure(
          String.format("Join stage '%s' uses a %s condition, which is not supported with the MapReduce engine.",
                        stageName, condition.getOp()),
          "Switch to a different execution engine.");
      }
      failureCollector.getOrThrowException();
      batchJoiner = new JoinerBridge(stageName, autoJoiner, joinDefinition);
      // this is equivalent to filtering out records that have a null key when they come from an optional stage
      if (condition.getOp() == JoinCondition.Op.KEY_EQUALITY && !((JoinCondition.OnKeys) condition).isNullSafe()) {
        filterNullKeyStages = joinDefinition.getStages().stream()
          .filter(s -> !s.isRequired())
          .map(JoinStage::getStageName)
          .collect(Collectors.toSet());
      }
    } else {
      batchJoiner = (BatchJoiner<?, ?, ?>) plugin;
    }
    BatchJoinerRuntimeContext runtimeContext = createRuntimeContext(stageSpec);
    batchJoiner.initialize(runtimeContext);
    if (isMapPhase) {
      return getTrackedEmitKeyStep(
        new MapperJoinerTransformation(batchJoiner, mapOutputKeyClassName, mapOutputValClassName, filterNullKeyStages),
        stageMetrics, getDataTracer(stageName), collector);
    } else {
      return getTrackedMergeStep(
        new ReducerJoinerTransformation(batchJoiner, mapOutputKeyClassName, mapOutputValClassName,
                                        runtimeContext.getInputSchemas().size()),
        stageMetrics, getDataTracer(stageName), collector);
    }
  }
  return super.getTransformation(stageSpec);
}
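
For contrast with the KEY_EQUALITY check above, the sketch below shows the kind of condition that would be rejected on the MapReduce engine. It assumes the join API's expression-based builder (JoinCondition.onExpression()); treat the exact builder calls as an assumption rather than verified usage.

// Assumed API: an expression-based condition whose op is not KEY_EQUALITY.
// On the MapReduce engine this would trigger the addFailure(...) call above.
JoinCondition expressionCondition = JoinCondition.onExpression()
  .setExpression("purchases.price > users.credit_limit")
  .build();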
Use of io.cdap.cdap.etl.api.join.JoinCondition in project cdap by caskdata.
Class MockAutoJoiner, method define.
@Nullable
@Override
public JoinDefinition define(AutoJoinerContext context) {
  if (conf.containsMacro(Conf.STAGES) || conf.containsMacro(Conf.KEY)
    || conf.containsMacro(Conf.REQUIRED) || conf.containsMacro(Conf.SELECT)) {
    return null;
  }
  Map<String, JoinStage> inputStages = context.getInputStages();
  List<JoinStage> from = new ArrayList<>(inputStages.size());
  Set<String> required = new HashSet<>(conf.getRequired());
  Set<String> broadcast = new HashSet<>(conf.getBroadcast());
  List<JoinField> selectedFields = conf.getSelect();
  boolean shouldGenerateSelected = selectedFields.isEmpty();
  JoinCondition condition = conf.getJoinConditionExpr();
  JoinCondition.OnKeys.Builder conditionBuilder = condition != null
    ? null
    : JoinCondition.onKeys().setNullSafe(conf.isNullSafe());
  for (String stageName : conf.getStages()) {
    JoinStage.Builder stageBuilder = JoinStage.builder(inputStages.get(stageName));
    if (!required.contains(stageName)) {
      stageBuilder.isOptional();
    }
    if (broadcast.contains(stageName)) {
      stageBuilder.setBroadcast(true);
    }
    JoinStage stage = stageBuilder.build();
    from.add(stage);
    if (conditionBuilder != null) {
      conditionBuilder.addKey(new JoinKey(stageName, conf.getKey()));
    }
    Schema stageSchema = stage.getSchema();
    if (!shouldGenerateSelected || stageSchema == null) {
      continue;
    }
    for (Schema.Field field : stageSchema.getFields()) {
      // alias everything to stage_field
      selectedFields.add(new JoinField(stageName, field.getName(),
                                       String.format("%s_%s", stageName, field.getName())));
    }
  }
  condition = condition == null ? conditionBuilder.build() : condition;
  JoinDefinition.Builder builder = JoinDefinition.builder()
    .select(selectedFields)
    .on(condition)
    .from(from)
    .setOutputSchemaName(String.join(".", conf.getStages()));
  Schema outputSchema = conf.getSchema();
  if (outputSchema != null) {
    builder.setOutputSchema(outputSchema);
  }
  if (conf.getDistributionName() != null && conf.getDistributionSize() != null) {
    builder.setDistributionFactor(conf.getDistributionSize(), conf.getDistributionName());
  }
  return builder.build();
}
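
As a usage sketch, here is roughly the JoinDefinition that define(...) would produce for a hypothetical configuration with two stages, a required "purchases" stage and an optional, broadcast "users" stage, joined on a single "user_id" key. Stage, field, and alias names are made up; the builder calls mirror the ones in the method above.

// Hypothetical two-stage configuration; names are illustrative.
JoinStage purchases = JoinStage.builder(inputStages.get("purchases")).build();
JoinStage.Builder usersBuilder = JoinStage.builder(inputStages.get("users"));
usersBuilder.isOptional();
usersBuilder.setBroadcast(true);
JoinStage users = usersBuilder.build();

JoinCondition.OnKeys.Builder conditionBuilder = JoinCondition.onKeys().setNullSafe(false);
conditionBuilder.addKey(new JoinKey("purchases", Collections.singletonList("user_id")));
conditionBuilder.addKey(new JoinKey("users", Collections.singletonList("user_id")));

JoinDefinition definition = JoinDefinition.builder()
  .select(Arrays.asList(new JoinField("purchases", "purchase_id", "purchases_purchase_id"),
                        new JoinField("users", "email", "users_email")))
  .on(conditionBuilder.build())
  .from(Arrays.asList(purchases, users))
  .setOutputSchemaName("purchases.users")
  .build();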