Usage of org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest.CacheToken in the Apache Beam project.
Class ProcessBundleHandlerTest, method testBundleProcessorReset.
/**
 * Checks that {@code BundleProcessor.reset()} discards the per-bundle state installed by
 * {@code setupForProcessBundleRequest} (instruction id, cache tokens, bundle-cache contents) and
 * asks each mocked collaborator to reset exactly once. It also checks that a later setup for a new
 * request yields a fresh bundle cache and no cache tokens.
 */
@Test
public void testBundleProcessorReset() throws Exception {
  // Every collaborator is a mock so the reset/clear calls made on it can be verified afterwards.
  ExecutionStateTracker tracker = mock(ExecutionStateTracker.class);
  MetricsContainerStepMap metricsMap = mock(MetricsContainerStepMap.class);
  PCollectionConsumerRegistry consumerRegistry = mock(PCollectionConsumerRegistry.class);
  PTransformFunctionRegistry startRegistry = mock(PTransformFunctionRegistry.class);
  PTransformFunctionRegistry finishRegistry = mock(PTransformFunctionRegistry.class);
  BundleSplitListener.InMemory inMemorySplitListener = mock(BundleSplitListener.InMemory.class);
  Collection<CallbackRegistration> finalizationCallbacks = mock(Collection.class);
  ProcessBundleHandler.HandleStateCallsForBundle stateCallHandler =
      mock(ProcessBundleHandler.HandleStateCallsForBundle.class);
  ThrowingRunnable onReset = mock(ThrowingRunnable.class);

  Cache<Object, Object> sharedCache = Caches.eternal();
  BundleProcessor processor =
      BundleProcessor.create(
          sharedCache,
          ProcessBundleDescriptor.getDefaultInstance(),
          startRegistry,
          finishRegistry,
          Collections.singletonList(onReset),
          new ArrayList<>(),
          new ArrayList<>(),
          inMemorySplitListener,
          consumerRegistry,
          metricsMap,
          tracker,
          stateCallHandler,
          finalizationCallbacks,
          new HashSet<>());
  // NOTE(review): finish() is invoked before the first setup; the reason is not evident from this
  // test alone — confirm the intended lifecycle.
  processor.finish();

  // First bundle: one side-input cache token plus an entry placed in the bundle cache.
  CacheToken.SideInput sideInput =
      CacheToken.SideInput.newBuilder().setTransformId("transformId").build();
  CacheToken token = CacheToken.newBuilder().setSideInput(sideInput).build();
  processor.setupForProcessBundleRequest(
      processBundleRequestFor("instructionId", "descriptorId", token));
  assertEquals("instructionId", processor.getInstructionId());
  assertThat(processor.getCacheTokens(), containsInAnyOrder(token));

  Cache<Object, Object> firstBundleCache = processor.getBundleCache();
  firstBundleCache.put("A", "B");
  assertEquals("B", firstBundleCache.peek("A"));

  processor.reset();

  // Per-bundle state is gone, including what was written into the old bundle cache...
  assertNull(processor.getInstructionId());
  assertNull(processor.getCacheTokens());
  assertNull(firstBundleCache.peek("A"));

  // ...and each collaborator was told to reset exactly once (verify(m) == verify(m, times(1))).
  verify(startRegistry).reset();
  verify(finishRegistry).reset();
  verify(inMemorySplitListener).clear();
  verify(consumerRegistry).reset();
  verify(metricsMap).reset();
  verify(tracker).reset();
  verify(finalizationCallbacks).clear();
  verify(onReset).run();

  // A subsequent setup must start clean: new bundle cache, new instruction id, no cache tokens.
  processor.setupForProcessBundleRequest(
      processBundleRequestFor("instructionId2", "descriptorId2"));
  assertNotSame(firstBundleCache, processor.getBundleCache());
  assertEquals("instructionId2", processor.getInstructionId());
  assertThat(processor.getCacheTokens(), is(emptyIterable()));
}
Aggregations