Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache: class GreedyPipelineFuser, method getDescendantConsumers.
/**
 * Retrieve all descendant {@link PTransformNode PTransforms} which are executed within an {@link
 * Environment}, such that there is a path between this input {@link PCollectionNode} and the
 * descendant {@link PTransformNode} with no intermediate {@link PTransformNode} which executes
 * within an environment.
 *
 * <p>This occurs as follows:
 *
 * <ul>
 *   <li>For each consumer of the input {@link PCollectionNode}:
 *       <ul>
 *         <li>If that {@link PTransformNode} executes within an environment, add it to the
 *             collection of descendants
 *         <li>If that {@link PTransformNode} does not execute within an environment, for each
 *             output {@link PCollectionNode} that that {@link PTransformNode} produces, add the
 *             result of recursively applying this method to that {@link PCollectionNode}.
 *       </ul>
 * </ul>
 *
 * <p>As {@link PCollectionNode PCollections} output by a {@link PTransformNode} that executes
 * within an {@link Environment} are not recursively inspected, {@link PTransformNode PTransforms}
 * reachable only via a path including that node as an intermediate node cannot be returned as a
 * descendant consumer of the original {@link PCollectionNode}.
 */
private DescendantConsumers getDescendantConsumers(PCollectionNode inputPCollection) {
  Set<PTransformNode> unfused = new HashSet<>();
  NavigableSet<CollectionConsumer> downstreamConsumers = new TreeSet<>();
  for (PTransformNode consumer : pipeline.getPerElementConsumers(inputPCollection)) {
    if (pipeline.getEnvironment(consumer).isPresent()) {
      // The base case: this consumer executes within an environment, so it is a fusible
      // consumer of the input PCollection.
      downstreamConsumers.add(CollectionConsumer.of(inputPCollection, consumer));
    } else {
      LOG.debug(
          "Adding {} {} to the set of runner-executed transforms",
          PTransformNode.class.getSimpleName(),
          consumer.getId());
      unfused.add(consumer);
      for (PCollectionNode output : pipeline.getOutputPCollections(consumer)) {
        // Recurse to all of the output PCollections of this PTransform.
        DescendantConsumers descendants = getDescendantConsumers(output);
        unfused.addAll(descendants.getUnfusedNodes());
        downstreamConsumers.addAll(descendants.getFusibleConsumers());
      }
    }
  }
  return DescendantConsumers.of(unfused, downstreamConsumers);
}
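The recursion above packages its two accumulators into a DescendantConsumers value object, whose definition is not shown in this snippet. The following is a minimal stand-in sketch, assuming only the of(), getUnfusedNodes(), and getFusibleConsumers() members that the method actually uses; the real class inside GreedyPipelineFuser may be defined differently (for example as an AutoValue type).

// Hypothetical stand-in for the DescendantConsumers value type used above; only the members
// exercised by the snippet are assumed.
static final class DescendantConsumers {
  private final Set<PTransformNode> unfusedNodes;
  private final NavigableSet<CollectionConsumer> fusibleConsumers;

  private DescendantConsumers(
      Set<PTransformNode> unfusedNodes, NavigableSet<CollectionConsumer> fusibleConsumers) {
    this.unfusedNodes = unfusedNodes;
    this.fusibleConsumers = fusibleConsumers;
  }

  static DescendantConsumers of(
      Set<PTransformNode> unfusedNodes, NavigableSet<CollectionConsumer> fusibleConsumers) {
    return new DescendantConsumers(unfusedNodes, fusibleConsumers);
  }

  Set<PTransformNode> getUnfusedNodes() {
    return unfusedNodes;
  }

  NavigableSet<CollectionConsumer> getFusibleConsumers() {
    return fusibleConsumers;
  }
}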
Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache: class GreedyPipelineFuser, method getRootConsumers.
private DescendantConsumers getRootConsumers(PTransformNode rootNode) {
  checkArgument(
      rootNode.getTransform().getInputsCount() == 0,
      "Transform %s is not at the root of the graph (consumes %s)",
      rootNode.getId(),
      rootNode.getTransform().getInputsMap());
  checkArgument(
      !pipeline.getEnvironment(rootNode).isPresent(),
      "%s requires all root nodes to be runner-implemented %s or %s primitives, "
          + "but transform %s executes in environment %s",
      GreedyPipelineFuser.class.getSimpleName(),
      PTransformTranslation.IMPULSE_TRANSFORM_URN,
      PTransformTranslation.READ_TRANSFORM_URN,
      rootNode.getId(),
      pipeline.getEnvironment(rootNode));
  Set<PTransformNode> unfused = new HashSet<>();
  unfused.add(rootNode);
  NavigableSet<CollectionConsumer> environmentNodes = new TreeSet<>();
  // Walk down until the first environments are found, and fuse them as appropriate.
  for (PCollectionNode output : pipeline.getOutputPCollections(rootNode)) {
    DescendantConsumers descendants = getDescendantConsumers(output);
    unfused.addAll(descendants.getUnfusedNodes());
    environmentNodes.addAll(descendants.getFusibleConsumers());
  }
  return DescendantConsumers.of(unfused, environmentNodes);
}
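getRootConsumers is only meaningful when applied to every root of the graph. How the fuser enumerates those roots is not part of this snippet; the sketch below is a hypothetical driver that recognizes roots the same way the first checkArgument above does (zero inputs), with allTransforms(...) standing in for whatever lookup the pipeline actually provides.

// Hypothetical driver: union the fusible consumers reachable from every root transform.
// allTransforms(pipeline) is an assumed helper, not part of the snippet above.
private DescendantConsumers getAllRootConsumers(QueryablePipeline pipeline) {
  Set<PTransformNode> unfused = new HashSet<>();
  NavigableSet<CollectionConsumer> fusible = new TreeSet<>();
  for (PTransformNode node : allTransforms(pipeline)) {
    if (node.getTransform().getInputsCount() == 0) {
      // Mirrors the root check enforced by getRootConsumers itself.
      DescendantConsumers rootConsumers = getRootConsumers(node);
      unfused.addAll(rootConsumers.getUnfusedNodes());
      fusible.addAll(rootConsumers.getFusibleConsumers());
    }
  }
  return DescendantConsumers.of(unfused, fusible);
}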
Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache: class GreedyStageFuser, method anyInputsSideInputs.
private static boolean anyInputsSideInputs(PTransformNode consumer, QueryablePipeline pipeline) {
  for (String inputPCollectionId : consumer.getTransform().getInputsMap().values()) {
    RunnerApi.PCollection pCollection =
        pipeline.getComponents().getPcollectionsMap().get(inputPCollectionId);
    PCollectionNode pCollectionNode = PipelineNode.pCollection(inputPCollectionId, pCollection);
    if (!pipeline.getSingletonConsumers(pCollectionNode).isEmpty()) {
      return true;
    }
  }
  return false;
}
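This predicate reports whether any PCollection consumed by the transform is also read as a singleton (side input) somewhere in the pipeline; such a PCollection generally has to be materialized rather than fused through. A hypothetical guard using it might look like the following; canFusePerElement is not a method from the snippet, only an illustration of where the check could sit.

// Hypothetical guard: refuse per-element fusion when the consumer reads any PCollection that is
// also consumed as a side input elsewhere, because that PCollection must be materialized.
private static boolean canFusePerElement(PTransformNode consumer, QueryablePipeline pipeline) {
  if (anyInputsSideInputs(consumer, pipeline)) {
    return false;
  }
  // Environment and other compatibility checks would follow here in a real fuser.
  return true;
}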
Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache: class GreedyStageFuser, method forGrpcPortRead.
/**
 * Returns an {@link ExecutableStage} where the initial {@link PTransformNode PTransform} is a
 * Remote gRPC Port Read, reading elements from the materialized {@link PCollectionNode
 * PCollection}.
 *
 * @param initialNodes the initial set of sibling transforms to fuse into this node. All of the
 *     transforms must consume the {@code inputPCollection} on a per-element basis, and must all
 *     be mutually compatible.
 */
public static ExecutableStage forGrpcPortRead(
    QueryablePipeline pipeline,
    PCollectionNode inputPCollection,
    Set<PTransformNode> initialNodes) {
  checkArgument(
      !initialNodes.isEmpty(),
      "%s must contain at least one %s.",
      GreedyStageFuser.class.getSimpleName(),
      PTransformNode.class.getSimpleName());
  // Choose the environment from an arbitrary node. The initial nodes may not be empty for this
  // subgraph to make any sense; there has to be at least one processor node
  // (otherwise the stage is gRPC Read -> gRPC Write, which doesn't do anything).
  Environment environment = getStageEnvironment(pipeline, initialNodes);
  ImmutableSet.Builder<PTransformNode> fusedTransforms = ImmutableSet.builder();
  fusedTransforms.addAll(initialNodes);
  Set<SideInputReference> sideInputs = new LinkedHashSet<>();
  Set<UserStateReference> userStates = new LinkedHashSet<>();
  Set<TimerReference> timers = new LinkedHashSet<>();
  Set<PCollectionNode> fusedCollections = new LinkedHashSet<>();
  Set<PCollectionNode> materializedPCollections = new LinkedHashSet<>();
  Queue<PCollectionNode> fusionCandidates = new ArrayDeque<>();
  for (PTransformNode initialConsumer : initialNodes) {
    fusionCandidates.addAll(pipeline.getOutputPCollections(initialConsumer));
    sideInputs.addAll(pipeline.getSideInputs(initialConsumer));
    userStates.addAll(pipeline.getUserStates(initialConsumer));
    timers.addAll(pipeline.getTimers(initialConsumer));
  }
  while (!fusionCandidates.isEmpty()) {
    PCollectionNode candidate = fusionCandidates.poll();
    if (fusedCollections.contains(candidate) || materializedPCollections.contains(candidate)) {
      // This should generally mean we get to a Flatten via multiple paths through the graph and
      // we've already determined what to do with the output.
      LOG.debug(
          "Skipping fusion candidate {} because it is {} in this {}",
          candidate,
          fusedCollections.contains(candidate) ? "fused" : "materialized",
          ExecutableStage.class.getSimpleName());
      continue;
    }
    PCollectionFusibility fusibility =
        canFuse(pipeline, candidate, environment, fusedCollections);
    switch (fusibility) {
      case MATERIALIZE:
        materializedPCollections.add(candidate);
        break;
      case FUSE:
        // All of the consumers of the candidate PCollection can be fused into this stage. Do so.
        fusedCollections.add(candidate);
        fusedTransforms.addAll(pipeline.getPerElementConsumers(candidate));
        for (PTransformNode consumer : pipeline.getPerElementConsumers(candidate)) {
          // The outputs of every transform fused into this stage must be either materialized or
          // themselves fused away, so add them to the set of candidates.
          fusionCandidates.addAll(pipeline.getOutputPCollections(consumer));
          sideInputs.addAll(pipeline.getSideInputs(consumer));
        }
        break;
      default:
        throw new IllegalStateException(
            String.format(
                "Unknown type of %s %s",
                PCollectionFusibility.class.getSimpleName(), fusibility));
    }
  }
  return ImmutableExecutableStage.ofFullComponents(
      pipeline.getComponents(),
      environment,
      inputPCollection,
      sideInputs,
      userStates,
      timers,
      fusedTransforms.build(),
      materializedPCollections,
      DEFAULT_WIRE_CODER_SETTINGS);
}
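A simplified call site for this factory is sketched below. It assumes that all per-element consumers of the materialized PCollection form a single compatible sibling set in one environment; the real GreedyPipelineFuser groups consumers into compatible sibling sets before calling forGrpcPortRead, so treat this only as an illustration of the arguments.

// Simplified, hypothetical call site: fuse every per-element consumer of a materialized
// PCollection into one stage rooted at a Remote gRPC Port Read.
static ExecutableStage stageFor(QueryablePipeline pipeline, PCollectionNode materialized) {
  // Assumes these consumers share one environment and are mutually compatible, which the
  // real fuser verifies before grouping them.
  Set<PTransformNode> siblings =
      new LinkedHashSet<>(pipeline.getPerElementConsumers(materialized));
  return GreedyStageFuser.forGrpcPortRead(pipeline, materialized, siblings);
}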
Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache: class GreedyPipelineFuser, method sanitizeDanglingPTransformInputs.
private static ExecutableStage sanitizeDanglingPTransformInputs(ExecutableStage stage) {
  /* Possible inputs to a PTransform can only be those which are:
   * <ul>
   *   <li>Explicit input PCollection to the stage
   *   <li>Outputs of a PTransform within the same stage
   *   <li>Timer PCollections
   *   <li>Side input PCollections
   *   <li>Explicit outputs from the stage
   * </ul>
   */
  Set<String> possibleInputs = new HashSet<>();
  possibleInputs.add(stage.getInputPCollection().getId());
  possibleInputs.addAll(
      stage.getOutputPCollections().stream()
          .map(PCollectionNode::getId)
          .collect(Collectors.toSet()));
  possibleInputs.addAll(
      stage.getSideInputs().stream()
          .map(s -> s.collection().getId())
          .collect(Collectors.toSet()));
  possibleInputs.addAll(
      stage.getTransforms().stream()
          .flatMap(t -> t.getTransform().getOutputsMap().values().stream())
          .collect(Collectors.toSet()));
  Set<String> danglingInputs =
      stage.getTransforms().stream()
          .flatMap(t -> t.getTransform().getInputsMap().values().stream())
          .filter(in -> !possibleInputs.contains(in))
          .collect(Collectors.toSet());
  ImmutableList.Builder<PTransformNode> pTransformNodesBuilder = ImmutableList.builder();
  for (PTransformNode transformNode : stage.getTransforms()) {
    PTransform transform = transformNode.getTransform();
    Map<String, String> validInputs =
        transform.getInputsMap().entrySet().stream()
            .filter(e -> !danglingInputs.contains(e.getValue()))
            .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
    if (!validInputs.equals(transform.getInputsMap())) {
      // Dangling inputs found, so recreate the PTransform without them.
      transformNode =
          PipelineNode.pTransform(
              transformNode.getId(),
              transform.toBuilder().clearInputs().putAllInputs(validInputs).build());
    }
    pTransformNodesBuilder.add(transformNode);
  }
  ImmutableList<PTransformNode> pTransformNodes = pTransformNodesBuilder.build();
  Components.Builder componentBuilder = stage.getComponents().toBuilder();
  // Update the PTransforms in components.
  componentBuilder
      .clearTransforms()
      .putAllTransforms(
          pTransformNodes.stream()
              .collect(Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform)));
  Map<String, PCollection> validPCollectionMap =
      stage.getComponents().getPcollectionsMap().entrySet().stream()
          .filter(e -> !danglingInputs.contains(e.getKey()))
          .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
  // Update the PCollections in components.
  componentBuilder.clearPcollections().putAllPcollections(validPCollectionMap);
  return ImmutableExecutableStage.of(
      componentBuilder.build(),
      stage.getEnvironment(),
      stage.getInputPCollection(),
      stage.getSideInputs(),
      stage.getUserStates(),
      stage.getTimers(),
      pTransformNodes,
      stage.getOutputPCollections(),
      stage.getWireCoderSettings());
}
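The heart of the sanitization is the per-transform input filter. The fragment below isolates that step on a standalone PTransform with made-up PCollection ids, to show what gets dropped and how the transform is rebuilt; it is illustrative only and not part of the method above.

// Illustrative only: a PTransform with one in-stage input and one dangling input (ids made up).
PTransform transform =
    PTransform.newBuilder()
        .putInputs("main", "pc_in_stage")
        .putInputs("extra", "pc_outside_stage")
        .build();
Set<String> danglingInputs = ImmutableSet.of("pc_outside_stage");
Map<String, String> validInputs =
    transform.getInputsMap().entrySet().stream()
        .filter(e -> !danglingInputs.contains(e.getValue()))
        .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
// validInputs now contains only {"main" -> "pc_in_stage"}; the transform is then rebuilt the
// same way sanitizeDanglingPTransformInputs does it.
PTransform sanitized =
    transform.toBuilder().clearInputs().putAllInputs(validInputs).build();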