Use of org.apache.flink.api.common.operators.util.UserCodeClassWrapper in project flink by apache.
In class ChainTaskTest, the method testMapTask:
@Test
public void testMapTask() {
    final int keyCnt = 100;
    final int valCnt = 20;
    final double memoryFraction = 1.0;
    try {
        // environment
        initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
        addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
        addOutput(this.outList);
        // chained combine config
        {
            final TaskConfig combineConfig = new TaskConfig(new Configuration());
            // input
            combineConfig.addInputToGroup(0);
            combineConfig.setInputSerializer(serFact, 0);
            // output
            combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            combineConfig.setOutputSerializer(serFact);
            // driver
            combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            combineConfig.setDriverComparator(compFact, 0);
            combineConfig.setDriverComparator(compFact, 1);
            combineConfig.setRelativeMemoryDriver(memoryFraction);
            // udf
            combineConfig.setStubWrapper(new UserCodeClassWrapper<>(MockCombiningReduceStub.class));
            getTaskConfig().addChainedTask(SynchronousChainedCombineDriver.class, combineConfig, "combine");
        }
        // chained map+combine
        {
            registerTask(FlatMapDriver.class, MockMapStub.class);
            BatchTask<FlatMapFunction<Record, Record>, Record> testTask = new BatchTask<>(this.mockEnv);
            try {
                testTask.invoke();
            } catch (Exception e) {
                e.printStackTrace();
                Assert.fail("Invoke method caused exception.");
            }
        }
        Assert.assertEquals(keyCnt, this.outList.size());
    } catch (Exception e) {
        e.printStackTrace();
        Assert.fail(e.getMessage());
    }
}
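In this test the UDF is registered by class rather than by instance: setStubWrapper(new UserCodeClassWrapper<>(MockCombiningReduceStub.class)) stores only the class in the TaskConfig, and the runtime instantiates it when the chained driver starts. Below is a minimal standalone sketch of that round trip, assuming the usual UserCodeWrapper accessors (hasObject(), getUserCodeObject()); the MyCombiner class is hypothetical and stands in for a mock stub.

    import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;

    public class ClassWrapperSketch {

        // Hypothetical UDF: any class with an accessible no-arg constructor can be wrapped.
        public static class MyCombiner {
            public int combine(int a, int b) {
                return a + b;
            }
        }

        public static void main(String[] args) {
            // What combineConfig.setStubWrapper(...) stores in the test above: the class only.
            UserCodeClassWrapper<MyCombiner> wrapper = new UserCodeClassWrapper<>(MyCombiner.class);

            System.out.println(wrapper.hasObject());   // false: no instance is carried

            // The runtime creates the UDF reflectively when the chained driver starts.
            MyCombiner udf = wrapper.getUserCodeObject();
            System.out.println(udf.combine(2, 3));     // 5
        }
    }

Because instantiation happens via reflection, a class wrapped this way needs an accessible no-argument constructor.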
Use of org.apache.flink.api.common.operators.util.UserCodeClassWrapper in project flink by apache.
In class ChainTaskTest, the method testFailingMapTask:
@Test
public void testFailingMapTask() {
    int keyCnt = 100;
    int valCnt = 20;
    final long memorySize = 1024 * 1024 * 3;
    final int bufferSize = 1014 * 1024;
    final double memoryFraction = 1.0;
    try {
        // environment
        initEnvironment(memorySize, bufferSize);
        addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
        addOutput(this.outList);
        // chained combine config
        {
            final TaskConfig combineConfig = new TaskConfig(new Configuration());
            // input
            combineConfig.addInputToGroup(0);
            combineConfig.setInputSerializer(serFact, 0);
            // output
            combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            combineConfig.setOutputSerializer(serFact);
            // driver
            combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            combineConfig.setDriverComparator(compFact, 0);
            combineConfig.setDriverComparator(compFact, 1);
            combineConfig.setRelativeMemoryDriver(memoryFraction);
            // udf
            combineConfig.setStubWrapper(new UserCodeClassWrapper<>(MockFailingCombineStub.class));
            getTaskConfig().addChainedTask(SynchronousChainedCombineDriver.class, combineConfig, "combine");
        }
        // chained map+combine
        {
            registerTask(FlatMapDriver.class, MockMapStub.class);
            final BatchTask<FlatMapFunction<Record, Record>, Record> testTask = new BatchTask<>(this.mockEnv);
            boolean stubFailed = false;
            try {
                testTask.invoke();
            } catch (Exception e) {
                stubFailed = true;
            }
            Assert.assertTrue("Function exception was not forwarded.", stubFailed);
        }
    } catch (Exception e) {
        e.printStackTrace();
        Assert.fail(e.getMessage());
    }
}
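The only change from the previous test is the wrapped UDF class: MockFailingCombineStub is expected to throw, and the assertion checks that the exception is forwarded out of testTask.invoke(). For contrast, the same package also offers UserCodeObjectWrapper, which ships a pre-configured, serializable instance instead of a class. The sketch below compares the two wrapping styles; it assumes UserCodeObjectWrapper takes the instance in its constructor, and MyFilter is a hypothetical UDF.

    import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
    import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
    import org.apache.flink.api.common.operators.util.UserCodeWrapper;

    public class WrapperChoiceSketch {

        // Hypothetical UDF; serializable so it can also be wrapped as an object.
        public static class MyFilter implements java.io.Serializable {
            private final int threshold;

            public MyFilter() { this(0); }                  // no-arg ctor for class wrapping
            public MyFilter(int threshold) { this.threshold = threshold; }

            public boolean accept(int value) { return value >= threshold; }
        }

        public static void main(String[] args) {
            // Class wrapper: only the class travels with the config; the runtime
            // instantiates it, so no constructor parameters can be carried along.
            UserCodeWrapper<MyFilter> byClass = new UserCodeClassWrapper<>(MyFilter.class);

            // Object wrapper: a configured instance is serialized into the config,
            // so constructor parameters (here: threshold = 5) are preserved.
            UserCodeWrapper<MyFilter> byObject = new UserCodeObjectWrapper<>(new MyFilter(5));

            System.out.println(byClass.getUserCodeObject().accept(3));   // true  (threshold 0)
            System.out.println(byObject.getUserCodeObject().accept(3));  // false (threshold 5)
        }
    }

The class wrapper is the natural fit for these tests because the mock UDFs need no configuration; the object wrapper becomes necessary once a function has to be parameterized before it is shipped.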
Use of org.apache.flink.api.common.operators.util.UserCodeClassWrapper in project flink by apache.
In class ChainedAllReduceDriverTest, the method testMapTask:
@Test
public void testMapTask() throws Exception {
    final int keyCnt = 100;
    final int valCnt = 20;
    final double memoryFraction = 1.0;
    // environment
    initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
    mockEnv.getExecutionConfig().enableObjectReuse();
    addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);
    addOutput(this.outList);
    // chained reduce config
    {
        final TaskConfig reduceConfig = new TaskConfig(new Configuration());
        // input
        reduceConfig.addInputToGroup(0);
        reduceConfig.setInputSerializer(serFact, 0);
        // output
        reduceConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
        reduceConfig.setOutputSerializer(serFact);
        // driver
        reduceConfig.setDriverStrategy(DriverStrategy.ALL_REDUCE);
        reduceConfig.setDriverComparator(compFact, 0);
        reduceConfig.setDriverComparator(compFact, 1);
        reduceConfig.setRelativeMemoryDriver(memoryFraction);
        // udf
        reduceConfig.setStubWrapper(new UserCodeClassWrapper<>(MockReduceStub.class));
        getTaskConfig().addChainedTask(ChainedAllReduceDriver.class, reduceConfig, "reduce");
    }
    // chained map+reduce
    {
        registerTask(FlatMapDriver.class, MockMapStub.class);
        BatchTask<FlatMapFunction<Record, Record>, Record> testTask = new BatchTask<>(mockEnv);
        testTask.invoke();
    }
    int sumTotal = valCnt * keyCnt * (keyCnt - 1) / 2;
    Assert.assertEquals(1, this.outList.size());
    Assert.assertEquals(sumTotal, this.outList.get(0).getField(0, IntValue.class).getValue());
}
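The asserted value is consistent with UniformRecordGenerator(keyCnt, valCnt, false) emitting each key in [0, keyCnt) exactly valCnt times: the chained ALL_REDUCE collapses all keyCnt * valCnt records into a single record whose first field is the sum of the keys. A quick standalone check of that arithmetic (not Flink code; the key distribution is an assumption implied by the expected value):

    public class ExpectedSumSketch {
        public static void main(String[] args) {
            int keyCnt = 100;
            int valCnt = 20;

            // Brute force: each key in [0, keyCnt) appears valCnt times.
            int bruteForce = 0;
            for (int key = 0; key < keyCnt; key++) {
                bruteForce += key * valCnt;
            }

            // Closed form used by the test (Gauss sum scaled by valCnt).
            int closedForm = valCnt * keyCnt * (keyCnt - 1) / 2;

            System.out.println(bruteForce);   // 99000
            System.out.println(closedForm);   // 99000
        }
    }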