Use of org.apache.flink.runtime.operators.util.TaskConfig in project flink by apache.
The class IterationSynchronizationSinkTask, method invoke().
// --------------------------------------------------------------------------------------------

@Override
public void invoke() throws Exception {
    this.headEventReader = new MutableRecordReader<>(
            getEnvironment().getInputGate(0),
            getEnvironment().getTaskManagerInfo().getTmpDirectories());

    TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());

    // store all aggregators
    this.aggregators = new HashMap<>();
    for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators(getUserCodeClassLoader())) {
        aggregators.put(aggWithName.getName(), aggWithName.getAggregator());
    }

    // store the aggregator convergence criterion
    if (taskConfig.usesConvergenceCriterion()) {
        convergenceCriterion = taskConfig.getConvergenceCriterion(getUserCodeClassLoader());
        convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName();
        Preconditions.checkNotNull(convergenceAggregatorName);
    }

    // store the default aggregator convergence criterion
    if (taskConfig.usesImplicitConvergenceCriterion()) {
        implicitConvergenceCriterion = taskConfig.getImplicitConvergenceCriterion(getUserCodeClassLoader());
        implicitConvergenceAggregatorName = taskConfig.getImplicitConvergenceCriterionAggregatorName();
        Preconditions.checkNotNull(implicitConvergenceAggregatorName);
    }

    maxNumberOfIterations = taskConfig.getNumberOfIterations();

    // set up the event handler
    int numEventsTillEndOfSuperstep = taskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0);
    eventHandler = new SyncEventHandler(numEventsTillEndOfSuperstep, aggregators, getEnvironment().getUserClassLoader());
    headEventReader.registerTaskEventListener(eventHandler, WorkerDoneEvent.class);

    IntValue dummy = new IntValue();

    while (!terminationRequested()) {
        if (log.isInfoEnabled()) {
            log.info(formatLogString("starting iteration [" + currentIteration + "]"));
        }

        // this call listens for events until the end-of-superstep is reached
        readHeadEventChannel(dummy);

        if (log.isInfoEnabled()) {
            log.info(formatLogString("finishing iteration [" + currentIteration + "]"));
        }

        if (checkForConvergence()) {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("signaling that all workers are to terminate in iteration [" + currentIteration + "]"));
            }
            requestTermination();
            sendToAllWorkers(new TerminationEvent());
        } else {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("signaling that all workers are done in iteration [" + currentIteration + "]"));
            }
            AllWorkersDoneEvent allWorkersDoneEvent = new AllWorkersDoneEvent(aggregators);
            sendToAllWorkers(allWorkersDoneEvent);

            // reset all aggregators
            for (Aggregator<?> agg : aggregators.values()) {
                agg.reset();
            }
            currentIteration++;
        }
    }
}
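The TaskConfig-specific part of this method is reading the registered iteration aggregators back out of the configuration. A minimal sketch of that lookup, pulled into a helper (AggregatorRegistry is hypothetical, not a Flink class):

import java.util.HashMap;
import java.util.Map;

import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.runtime.operators.util.TaskConfig;

public final class AggregatorRegistry {

    // Collects the iteration aggregators stored in a TaskConfig into a
    // name -> aggregator map, mirroring the loop in invoke() above.
    public static Map<String, Aggregator<?>> readAggregators(TaskConfig taskConfig, ClassLoader userCodeClassLoader) {
        Map<String, Aggregator<?>> aggregators = new HashMap<>();
        for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators(userCodeClassLoader)) {
            aggregators.put(aggWithName.getName(), aggWithName.getAggregator());
        }
        return aggregators;
    }
}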
Use of org.apache.flink.runtime.operators.util.TaskConfig in project flink by apache.
The class InputFormatVertex, method initializeOnMaster().
@Override
public void initializeOnMaster(ClassLoader loader) throws Exception {
    final TaskConfig cfg = new TaskConfig(getConfiguration());

    // deserialize from the payload
    UserCodeWrapper<InputFormat<?, ?>> wrapper;
    try {
        wrapper = cfg.getStubWrapper(loader);
    } catch (Throwable t) {
        throw new Exception("Deserializing the InputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
    }
    if (wrapper == null) {
        throw new Exception("No input format present in InputFormatVertex's task configuration.");
    }

    // instantiate, if necessary
    InputFormat<?, ?> inputFormat;
    try {
        inputFormat = wrapper.getUserCodeObject(InputFormat.class, loader);
    } catch (Throwable t) {
        throw new Exception("Instantiating the InputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
    }

    Thread thread = Thread.currentThread();
    ClassLoader original = thread.getContextClassLoader();

    // configure
    try {
        thread.setContextClassLoader(loader);
        inputFormat.configure(cfg.getStubParameters());
    } catch (Throwable t) {
        throw new Exception("Configuring the InputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
    } finally {
        thread.setContextClassLoader(original);
    }

    setInputSplitSource(inputFormat);
}
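The configure() call runs user code, so it is wrapped in the usual context-classloader swap with a finally block that restores the previous loader. A minimal, Flink-independent sketch of that pattern (runWithContextClassLoader is a hypothetical helper):

import java.util.concurrent.Callable;

public final class ClassLoaderUtil {

    // Runs 'action' with 'loader' as the thread's context ClassLoader and
    // restores the previous loader afterwards, even if the action throws.
    public static <T> T runWithContextClassLoader(ClassLoader loader, Callable<T> action) throws Exception {
        Thread thread = Thread.currentThread();
        ClassLoader original = thread.getContextClassLoader();
        try {
            thread.setContextClassLoader(loader);
            return action.call();
        } finally {
            thread.setContextClassLoader(original);
        }
    }
}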
Use of org.apache.flink.runtime.operators.util.TaskConfig in project flink by apache.
The class OutputFormatVertex, method finalizeOnMaster().
@Override
public void finalizeOnMaster(ClassLoader loader) throws Exception {
    final TaskConfig cfg = new TaskConfig(getConfiguration());

    UserCodeWrapper<OutputFormat<?>> wrapper;
    try {
        wrapper = cfg.<OutputFormat<?>>getStubWrapper(loader);
    } catch (Throwable t) {
        throw new Exception("Deserializing the OutputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
    }
    if (wrapper == null) {
        throw new Exception("No output format present in OutputFormatVertex's task configuration.");
    }

    OutputFormat<?> outputFormat;
    try {
        outputFormat = wrapper.getUserCodeObject(OutputFormat.class, loader);
    } catch (Throwable t) {
        throw new Exception("Instantiating the OutputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
    }

    try {
        outputFormat.configure(cfg.getStubParameters());
    } catch (Throwable t) {
        throw new Exception("Configuring the OutputFormat (" + formatDescription + ") failed: " + t.getMessage(), t);
    }

    if (outputFormat instanceof FinalizeOnMaster) {
        ((FinalizeOnMaster) outputFormat).finalizeGlobal(getParallelism());
    }
}
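The instanceof check at the end is the opt-in point: finalizeGlobal(int) runs once on the master after all parallel subtasks have finished, which suits commit-style cleanup. A sketch of a format opting into that hook (CommittingOutputFormat is hypothetical, and a real implementation would also implement OutputFormat):

import java.io.IOException;

import org.apache.flink.api.common.io.FinalizeOnMaster;

public class CommittingOutputFormat implements FinalizeOnMaster {

    @Override
    public void finalizeGlobal(int parallelism) throws IOException {
        // called exactly once after all 'parallelism' subtasks have finished,
        // e.g. to atomically promote a staging directory to its final location
        System.out.println("committing output written by " + parallelism + " subtasks");
    }
}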
Use of org.apache.flink.runtime.operators.util.TaskConfig in project flink by apache.
The class StreamingJobGraphGenerator, method createJobVertex().
private StreamConfig createJobVertex(Integer streamNodeId, Map<Integer, byte[]> hashes, List<Map<Integer, byte[]>> legacyHashes) {
    JobVertex jobVertex;
    StreamNode streamNode = streamGraph.getStreamNode(streamNodeId);

    byte[] hash = hashes.get(streamNodeId);
    if (hash == null) {
        throw new IllegalStateException("Cannot find node hash. Did you generate them before calling this method?");
    }
    JobVertexID jobVertexId = new JobVertexID(hash);

    List<JobVertexID> legacyJobVertexIds = new ArrayList<>(legacyHashes.size());
    for (Map<Integer, byte[]> legacyHash : legacyHashes) {
        hash = legacyHash.get(streamNodeId);
        if (null != hash) {
            legacyJobVertexIds.add(new JobVertexID(hash));
        }
    }

    if (streamNode.getInputFormat() != null) {
        jobVertex = new InputFormatVertex(chainedNames.get(streamNodeId), jobVertexId, legacyJobVertexIds);
        TaskConfig taskConfig = new TaskConfig(jobVertex.getConfiguration());
        taskConfig.setStubWrapper(new UserCodeObjectWrapper<Object>(streamNode.getInputFormat()));
    } else {
        jobVertex = new JobVertex(chainedNames.get(streamNodeId), jobVertexId, legacyJobVertexIds);
    }

    jobVertex.setResources(chainedMinResources.get(streamNodeId), chainedPreferredResources.get(streamNodeId));
    jobVertex.setInvokableClass(streamNode.getJobVertexClass());

    int parallelism = streamNode.getParallelism();
    if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT) {
        parallelism = defaultParallelism;
    }
    jobVertex.setParallelism(parallelism);
    jobVertex.setMaxParallelism(streamNode.getMaxParallelism());

    if (LOG.isDebugEnabled()) {
        LOG.debug("Parallelism set: {} for {}", parallelism, streamNodeId);
    }

    jobVertices.put(streamNodeId, jobVertex);
    builtVertices.add(streamNodeId);
    jobGraph.addVertex(jobVertex);

    return new StreamConfig(jobVertex.getConfiguration());
}
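The setStubWrapper call here is the write side of the pattern that InputFormatVertex.initializeOnMaster (above) reads back. A sketch of the full round trip through a single vertex Configuration (StubWrapperRoundTrip is a hypothetical helper; it assumes the format object is serializable):

import org.apache.flink.api.common.io.InputFormat;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.operators.util.TaskConfig;

public final class StubWrapperRoundTrip {

    public static InputFormat<?, ?> roundTrip(InputFormat<?, ?> format, ClassLoader loader) throws Exception {
        Configuration vertexConfig = new Configuration();

        // write side, as in createJobVertex above
        new TaskConfig(vertexConfig).setStubWrapper(new UserCodeObjectWrapper<Object>(format));

        // read side, as in InputFormatVertex.initializeOnMaster
        UserCodeWrapper<InputFormat<?, ?>> wrapper = new TaskConfig(vertexConfig).getStubWrapper(loader);
        return wrapper.getUserCodeObject(InputFormat.class, loader);
    }
}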
Use of org.apache.flink.runtime.operators.util.TaskConfig in project flink by apache.
The class ChainedAllReduceDriverTest, method testMapTask().
@Test
public void testMapTask() {
    final int keyCnt = 100;
    final int valCnt = 20;
    final double memoryFraction = 1.0;

    try {
        // environment
        initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
        mockEnv.getExecutionConfig().enableObjectReuse();
        addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
        addOutput(this.outList);

        // chained reduce config
        {
            final TaskConfig reduceConfig = new TaskConfig(new Configuration());

            // input
            reduceConfig.addInputToGroup(0);
            reduceConfig.setInputSerializer(serFact, 0);

            // output
            reduceConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            reduceConfig.setOutputSerializer(serFact);

            // driver
            reduceConfig.setDriverStrategy(DriverStrategy.ALL_REDUCE);
            reduceConfig.setDriverComparator(compFact, 0);
            reduceConfig.setDriverComparator(compFact, 1);
            reduceConfig.setRelativeMemoryDriver(memoryFraction);

            // udf
            reduceConfig.setStubWrapper(new UserCodeClassWrapper<>(MockReduceStub.class));

            getTaskConfig().addChainedTask(ChainedAllReduceDriver.class, reduceConfig, "reduce");
        }

        // chained map+reduce
        {
            BatchTask<FlatMapFunction<Record, Record>, Record> testTask = new BatchTask<>();
            registerTask(testTask, FlatMapDriver.class, MockMapStub.class);
            try {
                testTask.invoke();
            } catch (Exception e) {
                e.printStackTrace();
                Assert.fail("Invoke method caused exception.");
            }
        }

        int sumTotal = valCnt * keyCnt * (keyCnt - 1) / 2;

        Assert.assertEquals(1, this.outList.size());
        Assert.assertEquals(sumTotal, this.outList.get(0).getField(0, IntValue.class).getValue());
    } catch (Exception e) {
        e.printStackTrace();
        Assert.fail(e.getMessage());
    }
}
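The expected value follows from the input: assuming UniformRecordGenerator(keyCnt, valCnt, false) emits every key in [0, keyCnt) exactly valCnt times (which the test's closed form implies), the ALL_REDUCE sum over the key field is valCnt * (0 + 1 + ... + (keyCnt - 1)). A self-contained check of that arithmetic:

public final class SumTotalCheck {

    public static void main(String[] args) {
        int keyCnt = 100;
        int valCnt = 20;

        // brute-force sum over all generated key fields
        int bruteForce = 0;
        for (int key = 0; key < keyCnt; key++) {
            bruteForce += key * valCnt;
        }

        // closed form used by the test: valCnt * keyCnt * (keyCnt - 1) / 2 = 99000
        System.out.println(bruteForce == valCnt * keyCnt * (keyCnt - 1) / 2); // prints true
    }
}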