Use of org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Struct in project beam by apache.
In class DefaultJobBundleFactoryTest, method createsMultipleEnvironmentsWithSdkWorkerParallelism.
/**
 * Exercises {@link DefaultJobBundleFactory} with the {@code sdkWorkerParallelism} option.
 *
 * <p>Phase one sets the parallelism to 2 and checks that the first two {@code forStage} calls each
 * create an environment while the third call creates none. Phase two sets the parallelism to 0 and
 * expects {@code max(1, availableProcessors - 1)} environments to be created, with every {@code
 * forStage} call still yielding a distinct {@link StageBundleFactory}.
 */
@Test
public void createsMultipleEnvironmentsWithSdkWorkerParallelism() throws Exception {
  // The single environment used throughout, plus the mocked factory that hands out
  // remoteEnvironment for it.
  Environment env =
      Environment.newBuilder()
          .setUrn("env:urn:a")
          .setPayload(ByteString.copyFrom(new byte[1]))
          .build();
  ServerFactory defaultServerFactory = ServerFactory.createDefault();
  EnvironmentFactory envFactory = mock(EnvironmentFactory.class);
  when(envFactory.createEnvironment(eq(env), any())).thenReturn(remoteEnvironment);

  // Provider mock that always returns the environment factory above.
  EnvironmentFactory.Provider envFactoryProvider = mock(EnvironmentFactory.Provider.class);
  when(envFactoryProvider.createEnvironmentFactory(any(), any(), any(), any(), any(), any()))
      .thenReturn(envFactory);
  when(envFactoryProvider.getServerFactory()).thenReturn(defaultServerFactory);
  Map<String, Provider> providersByUrn = ImmutableMap.of(env.getUrn(), envFactoryProvider);

  // Phase 1: explicit parallelism of two.
  PortablePipelineOptions portableOptions =
      PipelineOptionsFactory.as(PortablePipelineOptions.class);
  portableOptions.setSdkWorkerParallelism(2);
  Struct twoWorkerOptions = PipelineOptionsTranslation.toProto(portableOptions);
  JobInfo twoWorkerJob = JobInfo.create("testJob", "testJob", "token", twoWorkerOptions);
  try (DefaultJobBundleFactory bundleFactory =
      new DefaultJobBundleFactory(twoWorkerJob, providersByUrn, stageIdGenerator, serverInfo)) {
    // Cumulative creation counts expected after each forStage call. The final entry repeats
    // the previous one: round robin, no new environment created.
    int[] expectedCreations = {1, 2, 2};
    for (int expectedCount : expectedCreations) {
      bundleFactory.forStage(getExecutableStage(env));
      verify(envFactoryProvider, Mockito.times(expectedCount))
          .createEnvironmentFactory(any(), any(), any(), any(), any(), any());
      verify(envFactory, Mockito.times(expectedCount)).createEnvironment(eq(env), any());
    }
  }

  // Phase 2: parallelism of zero; start counting createEnvironment invocations afresh.
  portableOptions.setSdkWorkerParallelism(0);
  Struct zeroParallelismOptions = PipelineOptionsTranslation.toProto(portableOptions);
  Mockito.reset(envFactory);
  when(envFactory.createEnvironment(eq(env), any())).thenReturn(remoteEnvironment);
  int expectedParallelism = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
  JobInfo zeroParallelismJob =
      JobInfo.create("testJob", "testJob", "token", zeroParallelismOptions);
  try (DefaultJobBundleFactory bundleFactory =
      new DefaultJobBundleFactory(
          zeroParallelismJob, providersByUrn, stageIdGenerator, serverInfo)) {
    HashSet<StageBundleFactory> distinctStageFactories = new HashSet<>();
    // more factories than parallelism for round-robin
    int requestedFactories = expectedParallelism + 5;
    for (int remaining = requestedFactories; remaining > 0; remaining--) {
      distinctStageFactories.add(bundleFactory.forStage(getExecutableStage(env)));
    }
    verify(envFactory, Mockito.times(expectedParallelism)).createEnvironment(eq(env), any());
    Assert.assertEquals(requestedFactories, distinctStageFactories.size());
  }
}
Use of org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Struct in project beam by apache.
In class DefaultJobBundleFactoryTest, method rejectsStateCachingWithLoadBalancing.
/**
 * Expects construction of a {@link DefaultJobBundleFactory} to fail with an {@link
 * IllegalArgumentException} whose message mentions {@code state_cache_size} when bundle load
 * balancing is enabled together with the {@code state_cache_size=1} experiment.
 */
@Test
public void rejectsStateCachingWithLoadBalancing() throws Exception {
  // Build options that turn on load balancing and request a state cache at the same time.
  PortablePipelineOptions loadBalancedOptions =
      PipelineOptionsFactory.as(PortablePipelineOptions.class);
  loadBalancedOptions.setLoadBalanceBundles(true);
  ExperimentalOptions experimentalView = loadBalancedOptions.as(ExperimentalOptions.class);
  ExperimentalOptions.addExperiment(experimentalView, "state_cache_size=1");
  Struct optionsProto = PipelineOptionsTranslation.toProto(experimentalView);

  IllegalArgumentException thrown =
      Assert.assertThrows(
          IllegalArgumentException.class,
          () -> {
            DefaultJobBundleFactory factory =
                new DefaultJobBundleFactory(
                    JobInfo.create("testJob", "testJob", "token", optionsProto),
                    envFactoryProviderMap,
                    stageIdGenerator,
                    serverInfo);
            // Only reached if construction unexpectedly succeeds.
            factory.close();
          });
  assertThat(thrown.getMessage(), containsString("state_cache_size"));
}
Use of org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Struct in project beam by apache.
In class SparkPipelineRunner, method main.
/**
 * Entry point intended solely for launching an executable jar laid out as described in {@link
 * PortablePipelineJarUtils}.
 *
 * <p>Loads the pipeline and its options from the classpath under the resolved base job name,
 * builds a {@link SparkPipelineRunner} from those options, and runs the pipeline.
 *
 * @param args command-line arguments, handed to {@code parseArgs}
 * @throws IllegalArgumentException via {@code Preconditions.checkArgument} if no base job name is
 *     configured and no default job name is found
 * @throws RuntimeException wrapping any exception thrown while running the job
 */
public static void main(String[] args) throws Exception {
  // Register standard file systems.
  FileSystems.setDefaultPipelineOptions(PipelineOptionsFactory.create());

  SparkPipelineRunnerConfiguration configuration = parseArgs(args);
  // An explicitly configured base job name wins; otherwise fall back to the jar's default.
  String baseJobName;
  if (configuration.baseJobName != null) {
    baseJobName = configuration.baseJobName;
  } else {
    baseJobName = PortablePipelineJarUtils.getDefaultJobName();
  }
  Preconditions.checkArgument(
      baseJobName != null,
      "No default job name found. Job name must be set using --base-job-name.");

  Pipeline pipeline = PortablePipelineJarUtils.getPipelineFromClasspath(baseJobName);
  Struct originalOptions = PortablePipelineJarUtils.getPipelineOptionsFromClasspath(baseJobName);

  // The retrieval token is only required by the legacy artifact service, which the Spark runner
  // no longer uses.
  String retrievalToken =
      ArtifactApi.CommitManifestResponse.Constants.NO_ARTIFACTS_STAGED_TOKEN
          .getValueDescriptor()
          .getOptions()
          .getExtension(RunnerApi.beamConstant);

  SparkPipelineOptions sparkOptions =
      PipelineOptionsTranslation.fromProto(originalOptions).as(SparkPipelineOptions.class);
  // Job name plus a random UUID identifies this particular invocation.
  String invocationId = String.format("%s_%s", sparkOptions.getJobName(), UUID.randomUUID());
  if (sparkOptions.getAppName() == null) {
    LOG.debug("App name was null. Using invocationId {}", invocationId);
    sparkOptions.setAppName(invocationId);
  }

  SparkPipelineRunner sparkRunner = new SparkPipelineRunner(sparkOptions);
  // Serialized after the app-name defaulting above so the proto carries the final options.
  Struct jobOptions = PipelineOptionsTranslation.toProto(sparkOptions);
  JobInfo jobInfo =
      JobInfo.create(invocationId, sparkOptions.getJobName(), retrievalToken, jobOptions);
  try {
    sparkRunner.run(pipeline, jobInfo);
  } catch (Exception cause) {
    throw new RuntimeException(String.format("Job %s failed.", invocationId), cause);
  }
  LOG.info("Job {} finished successfully.", invocationId);
}
Aggregations