the class StateAssignmentOperation method assignTaskStatesToOperatorInstances.
// Redistributes the state stored in taskState onto the subtasks of executionJobVertex.
// Visible steps:
//   (1) collect the managed/raw operator state and the managed/raw keyed state of every old subtask,
//   (2) repartition the operator state per chain index with the round-robin repartitioner,
//   (3) per new subtask, bundle legacy, operator and keyed state into a TaskStateHandles.
// NOTE(review): this extract has unbalanced braces and some guarded bodies are empty (e.g. the
// null checks on the keyed state handles below), so statements appear to be missing here.
private static void assignTaskStatesToOperatorInstances(TaskState taskState, ExecutionJobVertex executionJobVertex) {
final int oldParallelism = taskState.getParallelism();
final int newParallelism = executionJobVertex.getParallelism();
// indexed by the new subtask index further down; presumably one key-group range per new subtask
List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), newParallelism);
final int chainLength = taskState.getChainLength();
// operator chain idx -> list of the stored op states from all parallel instances for this chain idx
@SuppressWarnings("unchecked") List<OperatorStateHandle>[] parallelOpStatesBackend = new List[chainLength];
@SuppressWarnings("unchecked") List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
List<KeyGroupsStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
// walk over every subtask of the old execution; subtasks without stored state are skipped
for (int p = 0; p < oldParallelism; ++p) {
SubtaskState subtaskState = taskState.getState(p);
if (null != subtaskState) {
collectParallelStatesByChainOperator(parallelOpStatesBackend, subtaskState.getManagedOperatorState());
collectParallelStatesByChainOperator(parallelOpStatesStream, subtaskState.getRawOperatorState());
KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
// NOTE(review): the body of this null check is not visible; presumably the handle is added to
// parallelKeyedStatesBackend - confirm against the full file
if (null != keyedStateBackend) {
KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
// NOTE(review): same here for parallelKeyedStateStream
if (null != keyedStateStream) {
// operator chain index -> lists with collected states (one collection for each parallel subtasks)
@SuppressWarnings("unchecked") List<Collection<OperatorStateHandle>>[] partitionedParallelStatesBackend = new List[chainLength];
@SuppressWarnings("unchecked") List<Collection<OperatorStateHandle>>[] partitionedParallelStatesStream = new List[chainLength];
//TODO here we can employ different redistribution strategies for state, e.g. union state.
// For now we only offer round robin as the default.
OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
// repartition managed (backend) and raw (stream) operator state independently for each chain index
for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
List<OperatorStateHandle> chainOpParallelStatesBackend = parallelOpStatesBackend[chainIdx];
List<OperatorStateHandle> chainOpParallelStatesStream = parallelOpStatesStream[chainIdx];
partitionedParallelStatesBackend[chainIdx] = applyRepartitioner(opStateRepartitioner, chainOpParallelStatesBackend, oldParallelism, newParallelism);
partitionedParallelStatesStream[chainIdx] = applyRepartitioner(opStateRepartitioner, chainOpParallelStatesStream, oldParallelism, newParallelism);
// build the state bundle of each subtask of the new execution
for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) {
// non-partitioned state
// legacy operator state is only taken over when the parallelism is unchanged and the old
// subtask has state; in every other case nonPartitionableState stays null
ChainedStateHandle<StreamStateHandle> nonPartitionableState = null;
if (oldParallelism == newParallelism) {
if (taskState.getState(subTaskIdx) != null) {
nonPartitionableState = taskState.getState(subTaskIdx).getLegacyOperatorState();
// partitionable state
// Arrays.asList returns a fixed-size list backed by the array, so the set(chainIdx, ...) calls
// below fill the chainLength slots in place; slots that are never set remain null
@SuppressWarnings("unchecked") Collection<OperatorStateHandle>[] iab = new Collection[chainLength];
@SuppressWarnings("unchecked") Collection<OperatorStateHandle>[] ias = new Collection[chainLength];
List<Collection<OperatorStateHandle>> operatorStateFromBackend = Arrays.asList(iab);
List<Collection<OperatorStateHandle>> operatorStateFromStream = Arrays.asList(ias);
for (int chainIdx = 0; chainIdx < partitionedParallelStatesBackend.length; ++chainIdx) {
List<Collection<OperatorStateHandle>> redistributedOpStateBackend = partitionedParallelStatesBackend[chainIdx];
List<Collection<OperatorStateHandle>> redistributedOpStateStream = partitionedParallelStatesStream[chainIdx];
// pick this subtask's share of the repartitioned state, if there is any for the chain index
if (redistributedOpStateBackend != null) {
operatorStateFromBackend.set(chainIdx, redistributedOpStateBackend.get(subTaskIdx));
if (redistributedOpStateStream != null) {
operatorStateFromStream.set(chainIdx, redistributedOpStateStream.get(subTaskIdx));
// NOTE(review): currentExecutionAttempt is not used in the visible text; presumably it receives
// the TaskStateHandles built at the end - confirm against the full file
Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIdx].getCurrentExecutionAttempt();
List<KeyGroupsStateHandle> newKeyedStatesBackend;
List<KeyGroupsStateHandle> newKeyedStateStream;
// unchanged parallelism: reuse the old subtask's keyed state handles as they are, wrapped in
// singleton lists (null when the subtask or the handle is absent)
if (oldParallelism == newParallelism) {
SubtaskState subtaskState = taskState.getState(subTaskIdx);
if (subtaskState != null) {
KeyGroupsStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
KeyGroupsStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(oldKeyedStatesBackend) : null;
newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(oldKeyedStatesStream) : null;
} else {
newKeyedStatesBackend = null;
newKeyedStateStream = null;
} else {
// changed parallelism: derive the handles from the key-group range of this subtask; presumably
// getKeyGroupsStateHandles selects the handles covering that range - confirm in the helper
KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
TaskStateHandles taskStateHandles = new TaskStateHandles(nonPartitionableState, operatorStateFromBackend, operatorStateFromStream, newKeyedStatesBackend, newKeyedStateStream);
the class SavepointLoader method loadAndValidateSavepoint.
* Loads a savepoint back as a {@link CompletedCheckpoint}.
* <p>This method verifies that tasks and parallelism still match the savepoint parameters.
* @param jobId The JobID of the job to load the savepoint for.
* @param tasks Tasks that will possibly be reset
* @param savepointPath The path of the savepoint to rollback to
* @param classLoader The class loader to resolve serialized classes in legacy savepoint versions.
* @param allowNonRestoredState Allow to skip checkpoint state that cannot be mapped
* to any job vertex in tasks.
* @throws IllegalStateException If mismatch between program and savepoint state
* @throws IOException If savepoint store failure
// Loads the savepoint at savepointPath, keeps every TaskState whose JobVertexID resolves to an
// ExecutionJobVertex in tasks, and wraps the result in a CompletedCheckpoint.
// Throws IllegalStateException on a max-parallelism mismatch, or when a TaskState cannot be
// mapped to any vertex and allowNonRestoredState is false.
// NOTE(review): the logging calls in this extract look truncated (bare string literals right
// after ';' or '{') and closing braces are missing.
public static CompletedCheckpoint loadAndValidateSavepoint(JobID jobId, Map<JobVertexID, ExecutionJobVertex> tasks, String savepointPath, ClassLoader classLoader, boolean allowNonRestoredState) throws IOException {
// (1) load the savepoint
final Tuple2<Savepoint, StreamStateHandle> savepointAndHandle = SavepointStore.loadSavepointWithHandle(savepointPath, classLoader);
final Savepoint savepoint = savepointAndHandle.f0;
final StreamStateHandle metadataHandle = savepointAndHandle.f1;
final Map<JobVertexID, TaskState> taskStates = new HashMap<>(savepoint.getTaskStates().size());
// guards the legacy-ID expansion below so that it happens at most once per call
boolean expandedToLegacyIds = false;
// (2) validate it (parallelism, etc)
for (TaskState taskState : savepoint.getTaskStates()) {
ExecutionJobVertex executionJobVertex = tasks.get(taskState.getJobVertexID());
// On the first miss the lookup map is widened with legacy JobVertexIDs,
// for example as generated from older flink versions, to provide backwards compatibility.
if (executionJobVertex == null && !expandedToLegacyIds) {
// note: this reassigns the `tasks` parameter, so all later iterations use the widened map
tasks = ExecutionJobVertex.includeLegacyJobVertexIDs(tasks);
executionJobVertex = tasks.get(taskState.getJobVertexID());
expandedToLegacyIds = true;"Could not find ExecutionJobVertex. Including legacy JobVertexIDs in search.");
if (executionJobVertex != null) {
// the state is accepted when the max parallelism matches, or when
// isMaxParallelismConfigured() is false for the vertex of the new program
if (executionJobVertex.getMaxParallelism() == taskState.getMaxParallelism() || !executionJobVertex.isMaxParallelismConfigured()) {
taskStates.put(taskState.getJobVertexID(), taskState);
} else {
// NOTE(review): the first %s is filled with the `savepoint` object here, while the message
// further down uses `savepointPath` - confirm which one is intended
String msg = String.format("Failed to rollback to savepoint %s. " + "Max parallelism mismatch between savepoint state and new program. " + "Cannot map operator %s with max parallelism %d to new program with " + "max parallelism %d. This indicates that the program has been changed " + "in a non-compatible way after the savepoint.", savepoint, taskState.getJobVertexID(), taskState.getMaxParallelism(), executionJobVertex.getMaxParallelism());
throw new IllegalStateException(msg);
// no matching vertex: the state is dropped only when the caller opted in, otherwise loading fails
} else if (allowNonRestoredState) {"Skipping savepoint state for operator {}.", taskState.getJobVertexID());
} else {
String msg = String.format("Failed to rollback to savepoint %s. " + "Cannot map savepoint state for operator %s to the new program, " + "because the operator is not available in the new program. If " + "you want to allow to skip this, you can set the --allowNonRestoredState " + "option on the CLI.", savepointPath, taskState.getJobVertexID());
throw new IllegalStateException(msg);
// (3) convert to checkpoint so the system can fall back to it
CheckpointProperties props = CheckpointProperties.forStandardSavepoint();
// NOTE(review): the two 0L arguments are positional; presumably timestamps - confirm against
// the CompletedCheckpoint constructor
return new CompletedCheckpoint(jobId, savepoint.getCheckpointId(), 0L, 0L, taskStates, props, metadataHandle, savepointPath);
the class SavepointStore method loadSavepointWithHandle.
* Loads the savepoint at the specified path. This method returns the savepoint, as well as the
* handle to the metadata.
* @param savepointFileOrDirectory Path to the parent savepoint directory or the meta data file.
* @param classLoader The class loader used to resolve serialized classes from legacy savepoint formats.
* @return A tuple of the loaded savepoint and the handle to its metadata file
* @throws IOException Failures during load are forwarded
// Resolves the metadata file for the given savepoint file or directory, deserializes the
// savepoint from it, and returns the savepoint together with a FileStateHandle that points
// to the metadata file.
// NOTE(review): the logging calls and the try-with-resources header in this extract look
// truncated, and closing braces are missing.
public static Tuple2<Savepoint, StreamStateHandle> loadSavepointWithHandle(String savepointFileOrDirectory, ClassLoader classLoader) throws IOException {
checkNotNull(savepointFileOrDirectory, "savepointFileOrDirectory");
checkNotNull(classLoader, "classLoader");
Path path = new Path(savepointFileOrDirectory);"Loading savepoint from {}", path);
FileSystem fs = FileSystem.get(path.toUri());
FileStatus status = fs.getFileStatus(path);
// If this is a directory, we need to find the meta data file
if (status.isDir()) {
Path candidatePath = new Path(path, SAVEPOINT_METADATA_FILE);
if (fs.exists(candidatePath)) {
// from here on `path` refers to the metadata file rather than the directory
path = candidatePath;"Using savepoint file in {}", path);
} else {
throw new IOException("Cannot find meta data file in directory " + path + ". Please try to load the savepoint directly from the meta data file " + "instead of the directory.");
// load the savepoint
final Savepoint savepoint;
// NOTE(review): the resource expression is cut off in this extract; presumably it wraps an
// input stream opened on `path` - confirm against the full file
try (DataInputStream dis = new DataInputViewStreamWrapper( {
// expected layout as read here: an int magic number, then an int format version, then the
// payload handled by the serializer registered for that version
int magicNumber = dis.readInt();
if (magicNumber == MAGIC_NUMBER) {
int version = dis.readInt();
SavepointSerializer<?> serializer = SavepointSerializers.getSerializer(version);
savepoint = serializer.deserialize(dis, classLoader);
} else {
// NOTE(review): this is an unchecked RuntimeException, whereas the method's Javadoc only
// documents IOException
throw new RuntimeException("Unexpected magic number. This can have multiple reasons: " + "(1) You are trying to load a Flink 1.0 savepoint, which is not supported by this " + "version of Flink. (2) The file you were pointing to is not a savepoint at all. " + "(3) The savepoint file has been corrupted.");
// construct the stream handle to the metadata file
// we get the size best-effort
long size = 0;
try {
size = fs.getFileStatus(path).getLen();
} catch (Exception ignored) {
// we don't know the size, but we don't want to fail the savepoint loading for that
// (size stays 0 in that case)
StreamStateHandle metadataHandle = new FileStateHandle(path, size);
return new Tuple2<>(savepoint, metadataHandle);
the class PendingCheckpoint method acknowledgeTask.
* Acknowledges the task with the given execution attempt id and the given subtask state.
* @param executionAttemptId of the acknowledged task
* @param subtaskState of the acknowledged task
* @param metrics Checkpoint metrics for the stats
* @return TaskAcknowledgeResult of the operation
// Records the acknowledgement of one task for this pending checkpoint while holding `lock`.
// Result as far as visible here: DISCARDED when the checkpoint was already discarded;
// DUPLICATE when the attempt id is no longer pending but is in acknowledgedTasks; UNKNOWN when
// it is in neither; SUCCESS after the subtask state was stored and the statistics reported.
// NOTE(review): closing braces are missing in this extract, and no statement that adds the
// attempt id to acknowledgedTasks is visible.
public TaskAcknowledgeResult acknowledgeTask(ExecutionAttemptID executionAttemptId, SubtaskState subtaskState, CheckpointMetrics metrics) {
synchronized (lock) {
if (discarded) {
return TaskAcknowledgeResult.DISCARDED;
// remove() doubles as the lookup: a null result means this attempt is not (or no longer) awaited
final ExecutionVertex vertex = notYetAcknowledgedTasks.remove(executionAttemptId);
if (vertex == null) {
if (acknowledgedTasks.contains(executionAttemptId)) {
return TaskAcknowledgeResult.DUPLICATE;
} else {
return TaskAcknowledgeResult.UNKNOWN;
} else {
JobVertexID jobVertexID = vertex.getJobvertexId();
int subtaskIndex = vertex.getParallelSubtaskIndex();
long ackTimestamp = System.currentTimeMillis();
// remains 0 when the task acknowledged without any state
long stateSize = 0;
if (null != subtaskState) {
TaskState taskState = taskStates.get(jobVertexID);
// first acknowledgement with state for this job vertex: create its TaskState lazily
if (null == taskState) {
@SuppressWarnings("deprecation") ChainedStateHandle<StreamStateHandle> nonPartitionedState = subtaskState.getLegacyOperatorState();
ChainedStateHandle<OperatorStateHandle> partitioneableState = subtaskState.getManagedOperatorState();
//TODO this should go away when we remove chained state, assigning state to operators directly instead
// chain length is taken from the legacy state if present, otherwise from the managed
// operator state, otherwise it defaults to 1
int chainLength;
if (nonPartitionedState != null) {
chainLength = nonPartitionedState.getLength();
} else if (partitioneableState != null) {
chainLength = partitioneableState.getLength();
} else {
chainLength = 1;
taskState = new TaskState(jobVertexID, vertex.getTotalNumberOfParallelSubtasks(), vertex.getMaxParallelism(), chainLength);
taskStates.put(jobVertexID, taskState);
taskState.putState(subtaskIndex, subtaskState);
stateSize = subtaskState.getStateSize();
// publish the checkpoint statistics
// to prevent null-pointers from concurrent modification, copy reference onto stack
final PendingCheckpointStats statsCallback = this.statsCallback;
if (statsCallback != null) {
// Do this in millis because the web frontend works with them
// (integer division: the sub-millisecond part of the nanosecond value is dropped)
long alignmentDurationMillis = metrics.getAlignmentDurationNanos() / 1_000_000;
SubtaskStateStats subtaskStateStats = new SubtaskStateStats(subtaskIndex, ackTimestamp, stateSize, metrics.getSyncDurationMillis(), metrics.getAsyncDurationMillis(), metrics.getBytesBufferedInAlignment(), alignmentDurationMillis);
statsCallback.reportSubtaskStats(jobVertexID, subtaskStateStats);
return TaskAcknowledgeResult.SUCCESS;
the class RoundRobinOperatorStateRepartitioner method repartition.
* Repartition all named states.
// Builds one map (StreamStateHandle -> OperatorStateHandle) per new parallel subtask and
// returns them as a list of size `parallelism`, indexed by subtask.
// States in SPLIT_DISTRIBUTE mode are cut up by their offsets and handed out across the
// subtasks; states in BROADCAST mode are added unchanged to the map of every subtask.
// NOTE(review): closing braces are missing in this extract and some branch bodies are empty,
// so the comments below describe only what the visible statements do.
private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(GroupByStateNameResults nameToStateByMode, int parallelism) {
// We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the maps
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism);
// Initialize
for (int i = 0; i < parallelism; ++i) {
mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>());
// Start with the state handles we distribute round robin by splitting by offsets
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> distributeNameToState = nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
// subtask index at which the distribution of the next named state begins; carried over from
// one named state to the next via newStartParallelOp
int startParallelOp = 0;
// Iterate all named states and repartition one named state at a time per iteration
for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e : distributeNameToState.entrySet()) {
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
// Determine actual number of partitions for this named state
int totalPartitions = 0;
for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) {
totalPartitions += offsets.f1.getOffsets().length;
// Repartition the state across the parallel operator instances
// lstIdx: position in `current`; offsetIdx: position inside the offsets of that entry
int lstIdx = 0;
int offsetIdx = 0;
// every subtask gets baseFraction partitions; `remainder` is the number of leftover partitions
int baseFraction = totalPartitions / parallelism;
int remainder = totalPartitions % parallelism;
int newStartParallelOp = startParallelOp;
for (int i = 0; i < parallelism; ++i) {
// Preparation: calculate the actual index considering wrap around
int parallelOpIdx = (i + startParallelOp) % parallelism;
// Now calculate the number of partitions we will assign to the parallel instance in this round ...
int numberOfPartitionsToAssign = baseFraction;
// ... and distribute odd partitions while we still have some, one at a time
// NOTE(review): the body of the `remainder > 0` branch is empty in this extract; presumably it
// hands out one extra partition and decrements remainder - confirm against the full file
if (remainder > 0) {
} else if (remainder == 0) {
// We are out of odd partitions now and begin our next redistribution round with the current
// parallel operator to ensure fair load balance
newStartParallelOp = parallelOpIdx;
// Now start collecting the partitions for the parallel instance into this list
// NOTE(review): parallelOperatorState is only appended to in the visible text, never read
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> parallelOperatorState = new ArrayList<>();
while (numberOfPartitionsToAssign > 0) {
Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets = current.get(lstIdx);
long[] offsets = handleWithOffsets.f1.getOffsets();
// number of offsets of the current entry that have not been handed out yet
int remaining = offsets.length - offsetIdx;
// Repartition offsets
long[] offs;
// the entry holds more offsets than this subtask still needs: take a slice and remember
// where to continue inside the same entry
if (remaining > numberOfPartitionsToAssign) {
offs = Arrays.copyOfRange(offsets, offsetIdx, offsetIdx + numberOfPartitionsToAssign);
offsetIdx += numberOfPartitionsToAssign;
} else {
// otherwise take all remaining offsets of the entry and reset the in-entry position
// NOTE(review): no statement advancing lstIdx is visible in this extract
// GC
handleWithOffsets.f1 = null;
offs = Arrays.copyOfRange(offsets, offsetIdx, offsets.length);
offsetIdx = 0;
parallelOperatorState.add(new Tuple2<>(handleWithOffsets.f0, new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)));
// subtracts the full `remaining`; when the slice branch above was taken this appears to drive
// the counter below zero, which also terminates the while loop
numberOfPartitionsToAssign -= remaining;
// As a last step we merge partitions that use the same StreamStateHandle in a single
// OperatorStateHandle
Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
if (operatorStateHandle == null) {
operatorStateHandle = new OperatorStateHandle(new HashMap<String, OperatorStateHandle.StateMetaInfo>(), handleWithOffsets.f0);
mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
// register the offsets of this named state under its name in the (possibly shared) handle
operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
startParallelOp = newStartParallelOp;
// Now we also add the state handles marked for broadcast to all parallel instances
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState = nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST);
for (int i = 0; i < parallelism; ++i) {
Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e : broadcastNameToState.entrySet()) {
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : current) {
// reuse the OperatorStateHandle already registered for this stream handle, or create one
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
if (operatorStateHandle == null) {
operatorStateHandle = new OperatorStateHandle(new HashMap<String, OperatorStateHandle.StateMetaInfo>(), handleWithMetaInfo.f0);
mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
// broadcast state keeps its original meta info (all offsets) for every subtask
operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
return mergeMapList;