Use of `org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload` in the Apache Beam project: class `SingleEnvironmentInstanceJobBundleFactoryTest`, method `closeShutsDownEnvironmentsWhenSomeFail`.
// Verifies that closing the factory attempts to close every remote environment even when
// some of the close() calls throw, and that the non-propagated failure is attached to the
// propagated one as a suppressed exception.
@Test
public void closeShutsDownEnvironmentsWhenSomeFail() throws Exception {
  Pipeline pipeline = Pipeline.create();
  ExperimentalOptions.addExperiment(pipeline.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
  pipeline.apply("Create", Create.of(1, 2, 3));

  // Fuse the pipeline and take the first fused stage; its payload is the template for the
  // stages bound to the other two environments.
  ExecutableStage stageOne =
      GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline))
          .getFusedStages()
          .stream()
          .findFirst()
          .get();
  ExecutableStagePayload templatePayload =
      ExecutableStagePayload.parseFrom(stageOne.toPTransform("foo").getSpec().getPayload());

  Environment envTwo = Environments.createDockerEnvironment("second_env");
  ExecutableStage stageTwo =
      ExecutableStage.fromPayload(templatePayload.toBuilder().setEnvironment(envTwo).build());
  Environment envThree = Environments.createDockerEnvironment("third_env");
  ExecutableStage stageThree =
      ExecutableStage.fromPayload(templatePayload.toBuilder().setEnvironment(envThree).build());

  RemoteEnvironment remoteOne = mock(RemoteEnvironment.class, "First Remote Env");
  RemoteEnvironment remoteTwo = mock(RemoteEnvironment.class, "Second Remote Env");
  RemoteEnvironment remoteThree = mock(RemoteEnvironment.class, "Third Remote Env");
  when(environmentFactory.createEnvironment(stageOne.getEnvironment(), GENERATED_ID))
      .thenReturn(remoteOne);
  when(environmentFactory.createEnvironment(stageTwo.getEnvironment(), GENERATED_ID))
      .thenReturn(remoteTwo);
  when(environmentFactory.createEnvironment(stageThree.getEnvironment(), GENERATED_ID))
      .thenReturn(remoteThree);
  when(remoteOne.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);
  when(remoteTwo.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);
  when(remoteThree.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);

  // Materialize all three environments in the factory.
  factory.forStage(stageOne);
  factory.forStage(stageTwo);
  factory.forStage(stageThree);

  // The first and third environments fail on close; the second closes cleanly.
  IllegalStateException firstFailure = new IllegalStateException("first stage");
  doThrow(firstFailure).when(remoteOne).close();
  IllegalStateException thirdFailure = new IllegalStateException("third stage");
  doThrow(thirdFailure).when(remoteThree).close();

  try {
    factory.close();
    fail("Factory close should have thrown");
  } catch (IllegalStateException expected) {
    // Whichever failure propagates first must carry the other as suppressed.
    if (expected.equals(firstFailure)) {
      assertThat(ImmutableList.copyOf(expected.getSuppressed()), contains(thirdFailure));
    } else if (expected.equals(thirdFailure)) {
      assertThat(ImmutableList.copyOf(expected.getSuppressed()), contains(firstFailure));
    } else {
      // An unrelated IllegalStateException is a genuine test failure.
      throw expected;
    }
    // Every environment must have been asked to close, failures notwithstanding.
    verify(remoteOne).close();
    verify(remoteTwo).close();
    verify(remoteThree).close();
  }
}
Use of `org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload` in the Apache Beam project: class `ExecutableStage`, method `fromPayload`.
/**
 * Reconstructs an {@link ExecutableStage} from its serialized {@link ExecutableStagePayload}
 * representation.
 *
 * <p>The inverse of {@link #toPTransform}, which documents how the payload is assembled.
 *
 * <p>Note: some of the payload's contents duplicate information already present on the
 * {@link PTransform} that carries it. The {@link ExecutableStagePayload} is deliberately rich
 * enough that a {@code ProcessBundleDescriptor} can be built from the payload alone.
 */
static ExecutableStage fromPayload(ExecutableStagePayload payload) {
  // All ids below are resolved against the stage-local component set.
  Components stageComponents = payload.getComponents();
  Environment env = payload.getEnvironment();
  Collection<WireCoderSetting> coderSettings = payload.getWireCoderSettingsList();

  String inputId = payload.getInput();
  PCollectionNode inputNode =
      PipelineNode.pCollection(inputId, stageComponents.getPcollectionsOrThrow(inputId));

  List<SideInputReference> sideInputRefs =
      payload.getSideInputsList().stream()
          .map(ref -> SideInputReference.fromSideInputId(ref, stageComponents))
          .collect(Collectors.toList());
  List<UserStateReference> userStateRefs =
      payload.getUserStatesList().stream()
          .map(ref -> UserStateReference.fromUserStateId(ref, stageComponents))
          .collect(Collectors.toList());
  List<TimerReference> timerRefs =
      payload.getTimersList().stream()
          .map(ref -> TimerReference.fromTimerId(ref, stageComponents))
          .collect(Collectors.toList());
  List<PTransformNode> transformNodes =
      payload.getTransformsList().stream()
          .map(id -> PipelineNode.pTransform(id, stageComponents.getTransformsOrThrow(id)))
          .collect(Collectors.toList());
  List<PCollectionNode> outputNodes =
      payload.getOutputsList().stream()
          .map(id -> PipelineNode.pCollection(id, stageComponents.getPcollectionsOrThrow(id)))
          .collect(Collectors.toList());

  return ImmutableExecutableStage.of(
      stageComponents,
      env,
      inputNode,
      sideInputRefs,
      userStateRefs,
      timerRefs,
      transformNodes,
      outputNodes,
      coderSettings);
}
Use of `org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload` in the Apache Beam project: class `PipelineValidator`, method `validateExecutableStage`.
// Validates that an ExecutableStage PTransform's payload is internally consistent: its input
// appears among the transform's inputs, it contains at least one transform, and every
// transform/output id it names resolves within the stage-local components.
private static void validateExecutableStage(String id, PTransform transform, Components outerComponents, Set<String> requirements) throws Exception {
  ExecutableStagePayload payload = ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());

  // Everything within an ExecutableStagePayload uses only the stage's components.
  Components stageComponents = payload.getComponents();

  String stageInput = payload.getInput();
  checkArgument(
      transform.getInputsMap().containsValue(stageInput),
      "ExecutableStage %s uses unknown input %s",
      id,
      stageInput);

  checkArgument(
      !payload.getTransformsList().isEmpty(), "ExecutableStage %s contains no transforms", id);

  for (String innerTransformId : payload.getTransformsList()) {
    checkArgument(
        stageComponents.containsTransforms(innerTransformId),
        "ExecutableStage %s uses unknown transform %s",
        id,
        innerTransformId);
  }
  for (String stageOutputId : payload.getOutputsList()) {
    checkArgument(
        stageComponents.containsPcollections(stageOutputId),
        "ExecutableStage %s uses unknown output %s",
        id,
        stageOutputId);
  }

  validateComponents("ExecutableStage " + id, stageComponents, requirements);

  // TODO: Also validate that side inputs of all transforms within components.getTransforms()
  // are contained within payload.getSideInputsList()
}
Use of `org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload` in the Apache Beam project: class `ExecutableStage`, method `toPTransform`.
/**
 * Builds a composite {@link PTransform} equivalent to this {@link ExecutableStage}:
 *
 * <ul>
 *   <li>{@link PTransform#getSubtransformsList()} is empty, so executable stages are treated
 *       as primitive transforms.
 *   <li>The {@link PTransform#getInputsMap()} contains only the {@link PCollection
 *       PCollections} from {@link #getInputPCollection()} and {@link #getSideInputs()}.
 *   <li>The values of {@link PTransform#getOutputsMap()} are exactly the {@link PCollectionNode
 *       PCollections} from {@link #getOutputPCollections()}.
 *   <li>{@link PTransform#getSpec()} holds an {@link ExecutableStagePayload} whose inputs and
 *       outputs mirror the PTransform's, and whose transforms come from {@link #getTransforms}.
 * </ul>
 *
 * <p>{@link #fromPayload(ExecutableStagePayload)} reconstructs the executable stage from the
 * resulting {@link ExecutableStagePayload}.
 */
default PTransform toPTransform(String uniqueName) {
  PTransform.Builder transformBuilder = PTransform.newBuilder().setUniqueName(uniqueName);
  ExecutableStagePayload.Builder payloadBuilder = ExecutableStagePayload.newBuilder();

  payloadBuilder.setEnvironment(getEnvironment());
  payloadBuilder.addAllWireCoderSettings(getWireCoderSettings());

  // Populate inputs and outputs of the stage payload and outer PTransform simultaneously.
  PCollectionNode mainInput = getInputPCollection();
  transformBuilder.putInputs("input", mainInput.getId());
  payloadBuilder.setInput(mainInput.getId());

  for (SideInputReference sideInput : getSideInputs()) {
    // A side input of the ExecutableStage itself is uniquely identified by the pair of inner
    // PTransform name and local name.
    String outerLocalName =
        String.format("%s:%s", sideInput.transform().getId(), sideInput.localName());
    transformBuilder.putInputs(outerLocalName, sideInput.collection().getId());
    payloadBuilder.addSideInputs(
        SideInputId.newBuilder()
            .setTransformId(sideInput.transform().getId())
            .setLocalName(sideInput.localName()));
  }

  for (UserStateReference userState : getUserStates()) {
    payloadBuilder.addUserStates(
        UserStateId.newBuilder()
            .setTransformId(userState.transform().getId())
            .setLocalName(userState.localName()));
  }

  for (TimerReference timer : getTimers()) {
    payloadBuilder.addTimers(
        TimerId.newBuilder()
            .setTransformId(timer.transform().getId())
            .setLocalName(timer.localName()));
  }

  // Outputs get synthetic, index-based local names on the outer transform.
  int outputIndex = 0;
  for (PCollectionNode output : getOutputPCollections()) {
    transformBuilder.putOutputs(String.format("materialized_%d", outputIndex), output.getId());
    payloadBuilder.addOutputs(output.getId());
    outputIndex++;
  }

  // stage payload.
  for (PTransformNode stageTransform : getTransforms()) {
    payloadBuilder.addTransforms(stageTransform.getId());
  }

  // Replace the component transform map with exactly the transforms of this stage.
  payloadBuilder.setComponents(
      getComponents().toBuilder()
          .clearTransforms()
          .putAllTransforms(
              getTransforms().stream()
                  .collect(
                      Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform))));

  transformBuilder.setSpec(
      FunctionSpec.newBuilder()
          .setUrn(ExecutableStage.URN)
          .setPayload(payloadBuilder.build().toByteString())
          .build());
  return transformBuilder.build();
}
Use of `org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload` in the Apache Beam project: class `ExecutableStageDoFnOperatorTest`, method `getOperator`.
// Builds an ExecutableStageDoFnOperator wired to a mocked Flink stage context. A non-null
// keyCoder selects the user-state payload variant; the state request handler is injected
// directly via Whitebox since the field has no setter.
@SuppressWarnings("rawtypes")
private ExecutableStageDoFnOperator getOperator(
    TupleTag<Integer> mainOutput,
    List<TupleTag<?>> additionalOutputs,
    DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory,
    WindowingStrategy windowingStrategy,
    @Nullable Coder keyCoder,
    Coder windowedInputCoder) {
  FlinkExecutableStageContextFactory contextFactory =
      Mockito.mock(FlinkExecutableStageContextFactory.class);
  when(contextFactory.get(any())).thenReturn(stageContext);

  // Keyed operators exercise user state, so they need the stateful payload variant.
  final ExecutableStagePayload selectedPayload =
      keyCoder == null ? this.stagePayload : this.stagePayloadWithUserState;

  ExecutableStageDoFnOperator<Integer, Integer> operator =
      new ExecutableStageDoFnOperator<>(
          "transform",
          windowedInputCoder,
          Collections.emptyMap(),
          mainOutput,
          additionalOutputs,
          outputManagerFactory,
          Collections.emptyMap(), /* sideInputTagMapping */
          Collections.emptyList(), /* sideInputs */
          Collections.emptyMap(), /* sideInputId mapping */
          FlinkPipelineOptions.defaults(),
          selectedPayload,
          jobInfo,
          contextFactory,
          createOutputMap(mainOutput, additionalOutputs),
          windowingStrategy,
          keyCoder,
          keyCoder != null ? new KvToByteBufferKeySelector<>(keyCoder, null) : null);
  Whitebox.setInternalState(operator, "stateRequestHandler", stateRequestHandler);
  return operator;
}
Aggregations