Use of org.apache.flink.runtime.state.OperatorStateHandle in project flink by apache.
Class AbstractStreamOperatorTestHarness, method initializeState.
/**
* Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
* Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
* if it was not called before.
*
* <p>This will reshape the state handles to include only those key-group states
* in the local key-group range and the operator states that would be assigned to the local
* subtask.
*/
public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
    if (!setupCalled) {
        setup();
    }

    if (operatorStateHandles != null) {
        int numKeyGroups = getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks();
        int numSubtasks = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
        int subtaskIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();

        // create a new OperatorStateHandles that only contains the state for our key-groups
        List<KeyGroupRange> keyGroupPartitions =
                StateAssignmentOperation.createKeyGroupPartitions(numKeyGroups, numSubtasks);

        KeyGroupRange localKeyGroupRange = keyGroupPartitions.get(subtaskIndex);

        List<KeyGroupsStateHandle> localManagedKeyGroupState = null;
        if (operatorStateHandles.getManagedKeyedState() != null) {
            localManagedKeyGroupState = StateAssignmentOperation.getKeyGroupsStateHandles(
                    operatorStateHandles.getManagedKeyedState(), localKeyGroupRange);
        }

        List<KeyGroupsStateHandle> localRawKeyGroupState = null;
        if (operatorStateHandles.getRawKeyedState() != null) {
            localRawKeyGroupState = StateAssignmentOperation.getKeyGroupsStateHandles(
                    operatorStateHandles.getRawKeyedState(), localKeyGroupRange);
        }

        List<OperatorStateHandle> managedOperatorState = new ArrayList<>();
        if (operatorStateHandles.getManagedOperatorState() != null) {
            managedOperatorState.addAll(operatorStateHandles.getManagedOperatorState());
        }
        Collection<OperatorStateHandle> localManagedOperatorState =
                operatorStateRepartitioner.repartitionState(managedOperatorState, numSubtasks).get(subtaskIndex);

        List<OperatorStateHandle> rawOperatorState = new ArrayList<>();
        if (operatorStateHandles.getRawOperatorState() != null) {
            rawOperatorState.addAll(operatorStateHandles.getRawOperatorState());
        }
        Collection<OperatorStateHandle> localRawOperatorState =
                operatorStateRepartitioner.repartitionState(rawOperatorState, numSubtasks).get(subtaskIndex);

        OperatorStateHandles massagedOperatorStateHandles = new OperatorStateHandles(
                0,
                operatorStateHandles.getLegacyOperatorState(),
                localManagedKeyGroupState,
                localRawKeyGroupState,
                localManagedOperatorState,
                localRawOperatorState);

        operator.initializeState(massagedOperatorStateHandles);
    } else {
        operator.initializeState(null);
    }
    initializeCalled = true;
}
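A minimal restore round-trip sketch (hypothetical test code, not part of the Flink sources on this page): it assumes a OneInputStreamOperatorTestHarness named harness that wraps a stateful operator, and a second harness restoredHarness built around a fresh instance of the same operator. The OperatorStateHandles returned by snapshot() are handed to initializeState(), which re-scopes them to the local key-group range and subtask as described in the Javadoc above.

// hypothetical harnesses: "harness" and "restoredHarness"
harness.open();
harness.processElement(new StreamRecord<>("event-1", 1L));

// take a snapshot; see the snapshot() method below for what it packs together
OperatorStateHandles snapshot = harness.snapshot(0L, 0L);
harness.close();

// the new harness re-scopes the handles to its own key-group range / subtask index
restoredHarness.initializeState(snapshot);
restoredHarness.open();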
Use of org.apache.flink.runtime.state.OperatorStateHandle in project flink by apache.
Class AbstractStreamOperatorTestHarness, method snapshot.
/**
* Calls {@link StreamOperator#snapshotState(long, long, CheckpointOptions)}.
*/
public OperatorStateHandles snapshot(long checkpointId, long timestamp) throws Exception {
    CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(new JobID(), "test_op");

    OperatorSnapshotResult operatorStateResult = operator.snapshotState(checkpointId, timestamp, CheckpointOptions.forFullCheckpoint());

    KeyGroupsStateHandle keyedManaged = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getKeyedStateManagedFuture());
    KeyGroupsStateHandle keyedRaw = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getKeyedStateRawFuture());
    OperatorStateHandle opManaged = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getOperatorStateManagedFuture());
    OperatorStateHandle opRaw = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getOperatorStateRawFuture());

    return new OperatorStateHandles(
            0,
            null,
            keyedManaged != null ? Collections.singletonList(keyedManaged) : null,
            keyedRaw != null ? Collections.singletonList(keyedRaw) : null,
            opManaged != null ? Collections.singletonList(opManaged) : null,
            opRaw != null ? Collections.singletonList(opRaw) : null);
}
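A small inspection sketch (hypothetical, using only the accessors that appear elsewhere on this page, and assuming java.util.Arrays is imported): after taking a snapshot, the managed operator state can be walked to see which state names map to which partition offsets inside the delegate stream handle.

// hypothetical: "harness" is an AbstractStreamOperatorTestHarness instance
OperatorStateHandles handles = harness.snapshot(1L, 100L);

if (handles.getManagedOperatorState() != null) {
    for (OperatorStateHandle handle : handles.getManagedOperatorState()) {
        for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry :
                handle.getStateNameToPartitionOffsets().entrySet()) {
            // each state name points at one or more offsets into the delegate handle
            System.out.println(entry.getKey() + " -> " + Arrays.toString(entry.getValue().getOffsets()));
        }
    }
}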
Use of org.apache.flink.runtime.state.OperatorStateHandle in project flink by apache.
Class RoundRobinOperatorStateRepartitioner, method groupByStateName.
/**
* Group by the different named states.
*/
@SuppressWarnings("unchecked, rawtype")
private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previousParallelSubtaskStates) {

    // Reorganize: group by (state name -> StreamStateHandle + StateMetaInfo), per distribution mode
    EnumMap<OperatorStateHandle.Mode, Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode =
            new EnumMap<>(OperatorStateHandle.Mode.class);

    for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
        Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> map = new HashMap<>();
        nameToStateByMode.put(mode, map);
    }

    for (OperatorStateHandle psh : previousParallelSubtaskStates) {
        for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : psh.getStateNameToPartitionOffsets().entrySet()) {

            OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();

            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
                    nameToStateByMode.get(metaInfo.getDistributionMode());

            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations = nameToState.get(e.getKey());
            if (stateLocations == null) {
                stateLocations = new ArrayList<>();
                nameToState.put(e.getKey(), stateLocations);
            }

            stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
        }
    }

    return new GroupByStateNameResults(nameToStateByMode);
}
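The null-check-and-insert on stateLocations is the pre-Java-8 form of Map.computeIfAbsent; a behaviorally equivalent sketch of the inner loops using that method (an illustration, not the Flink code itself) would be:

for (OperatorStateHandle psh : previousParallelSubtaskStates) {
    for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : psh.getStateNameToPartitionOffsets().entrySet()) {
        // look up the map for this distribution mode, create the per-name list lazily, then append
        nameToStateByMode
                .get(e.getValue().getDistributionMode())
                .computeIfAbsent(e.getKey(), name -> new ArrayList<>())
                .add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
    }
}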
Use of org.apache.flink.runtime.state.OperatorStateHandle in project flink by apache.
Class CheckpointStateRestoreTest, method testSetState.
/**
* Tests that on restore the task state is reset for each stateful task.
*/
@Test
public void testSetState() {
    try {
        final ChainedStateHandle<StreamStateHandle> serializedState =
                CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
        KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
        List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
        final KeyGroupsStateHandle serializedKeyGroupStates =
                CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);

        final JobID jid = new JobID();
        final JobVertexID statefulId = new JobVertexID();
        final JobVertexID statelessId = new JobVertexID();

        Execution statefulExec1 = mockExecution();
        Execution statefulExec2 = mockExecution();
        Execution statefulExec3 = mockExecution();
        Execution statelessExec1 = mockExecution();
        Execution statelessExec2 = mockExecution();

        ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 3);
        ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1, 3);
        ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2, 3);
        ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 2);
        ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1, 2);

        ExecutionJobVertex stateful = mockExecutionJobVertex(statefulId,
                new ExecutionVertex[] { stateful1, stateful2, stateful3 });
        ExecutionJobVertex stateless = mockExecutionJobVertex(statelessId,
                new ExecutionVertex[] { stateless1, stateless2 });

        Map<JobVertexID, ExecutionJobVertex> map = new HashMap<JobVertexID, ExecutionJobVertex>();
        map.put(statefulId, stateful);
        map.put(statelessId, stateless);

        CheckpointCoordinator coord = new CheckpointCoordinator(
                jid,
                200000L,
                200000L,
                0,
                Integer.MAX_VALUE,
                ExternalizedCheckpointSettings.none(),
                new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
                new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
                new ExecutionVertex[0],
                new StandaloneCheckpointIDCounter(),
                new StandaloneCompletedCheckpointStore(1),
                null,
                Executors.directExecutor());

        // create ourselves a checkpoint with state
        final long timestamp = 34623786L;
        coord.triggerCheckpoint(timestamp, false);

        PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
        final long checkpointId = pending.getCheckpointId();

        SubtaskState checkpointStateHandles =
                new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null);

        coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
        coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
        coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
        coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
        coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));

        assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
        assertEquals(0, coord.getNumberOfPendingCheckpoints());

        // let the coordinator inject the state
        coord.restoreLatestCheckpointedState(map, true, false);

        // verify that each stateful vertex got the state
        final TaskStateHandles taskStateHandles = new TaskStateHandles(
                serializedState,
                Collections.<Collection<OperatorStateHandle>>singletonList(null),
                Collections.<Collection<OperatorStateHandle>>singletonList(null),
                Collections.singletonList(serializedKeyGroupStates),
                null);

        BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() {

            @Override
            public boolean matches(Object o) {
                if (o instanceof TaskStateHandles) {
                    return o.equals(taskStateHandles);
                }
                return false;
            }

            @Override
            public void describeTo(Description description) {
                description.appendValue(taskStateHandles);
            }
        };

        verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher));
        verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher));
        verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher));
        verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
        verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
    } catch (Exception e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
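The mockExecution() and mockExecutionVertex() helpers are not included in this excerpt. A plausible minimal shape of mockExecution(), inferred only from how the test uses it (the coordinator asks each execution for its attempt id) and offered purely as an assumption, could look like this Mockito sketch:

// hypothetical reconstruction, not the actual helper from CheckpointStateRestoreTest
private Execution mockExecution() {
    Execution mock = Mockito.mock(Execution.class);
    // the coordinator routes acknowledge messages by attempt id
    Mockito.when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
    Mockito.when(mock.getState()).thenReturn(ExecutionState.RUNNING);
    return mock;
}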
Use of org.apache.flink.runtime.state.OperatorStateHandle in project flink by apache.
Class CheckpointCoordinatorTest, method testReplicateModeStateHandle.
@Test
public void testReplicateModeStateHandle() {
    Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = new HashMap<>(1);
    metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[] { 0, 23 }, OperatorStateHandle.Mode.BROADCAST));
    metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[] { 42, 64 }, OperatorStateHandle.Mode.BROADCAST));
    metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[] { 72, 83 }, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));

    OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100]));

    OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
    List<Collection<OperatorStateHandle>> repartitionedStates =
            repartitioner.repartitionState(Collections.singletonList(osh), 3);

    // count how often each state name shows up across the 3 new subtasks
    Map<String, Integer> checkCounts = new HashMap<>(3);

    for (Collection<OperatorStateHandle> operatorStateHandles : repartitionedStates) {
        for (OperatorStateHandle operatorStateHandle : operatorStateHandles) {
            for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> stateNameToMetaInfo :
                    operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {

                String stateName = stateNameToMetaInfo.getKey();
                Integer count = checkCounts.get(stateName);
                if (null == count) {
                    checkCounts.put(stateName, 1);
                } else {
                    checkCounts.put(stateName, 1 + count);
                }

                OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue();
                if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) {
                    // split state is handed out one partition (offset) per receiving subtask
                    Assert.assertEquals(1, stateNameToMetaInfo.getValue().getOffsets().length);
                } else {
                    // broadcast state keeps both offsets on every subtask
                    Assert.assertEquals(2, stateNameToMetaInfo.getValue().getOffsets().length);
                }
            }
        }
    }

    Assert.assertEquals(3, checkCounts.size());
    // the two BROADCAST states are replicated to all 3 subtasks ...
    Assert.assertEquals(3, checkCounts.get("t-1").intValue());
    Assert.assertEquals(3, checkCounts.get("t-2").intValue());
    // ... while the SPLIT_DISTRIBUTE state has only 2 partitions, so it reaches 2 of the 3 subtasks
    Assert.assertEquals(2, checkCounts.get("t-3").intValue());
}