Use of org.apache.flink.api.common.aggregators.LongSumAggregator in project flink by apache.
The class AggregatorConvergenceITCase, method testDeltaConnectedComponentsWithParametrizableConvergence.
@Test
public void testDeltaConnectedComponentsWithParametrizableConvergence() throws Exception {
    // name of the aggregator that checks for convergence
    final String UPDATED_ELEMENTS = "updated.elements.aggr";
    // the iteration stops if fewer than this number of elements change value
    final long convergence_threshold = 3;

    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
    DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);

    DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
        initialSolutionSet.iterateDelta(initialSolutionSet, 10, 0);

    // register the convergence criterion
    iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS, new LongSumAggregator(),
        new UpdatedElementsConvergenceCriterion(convergence_threshold));

    DataSet<Tuple2<Long, Long>> verticesWithNewComponents = iteration.getWorkset()
        .join(edges).where(0).equalTo(0)
        .with(new NeighborWithComponentIDJoin())
        .groupBy(0).min(1);

    DataSet<Tuple2<Long, Long>> updatedComponentId = verticesWithNewComponents
        .join(iteration.getSolutionSet()).where(0).equalTo(0)
        .flatMap(new MinimumIdFilter(UPDATED_ELEMENTS));

    List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId, updatedComponentId).collect();
    Collections.sort(result, new JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>());

    assertEquals(expectedResult, result);
}
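The UpdatedElementsConvergenceCriterion referenced above is defined elsewhere in the test class and not shown on this page. A minimal sketch, assuming it simply compares the per-superstep update count against the constructor threshold (the class name and constructor argument are taken from the call site above; the body is an assumption):

public static class UpdatedElementsConvergenceCriterion implements ConvergenceCriterion<LongValue> {

    private final long threshold;

    public UpdatedElementsConvergenceCriterion(long threshold) {
        this.threshold = threshold;
    }

    @Override
    public boolean isConverged(int iteration, LongValue value) {
        // converged once fewer elements than the threshold changed value in the last superstep
        return value.getValue() < threshold;
    }
}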
Use of org.apache.flink.api.common.aggregators.LongSumAggregator in project flink by apache.
The class AggregatorsITCase, method testAggregatorWithParameterForIterateDelta.
@Test
public void testAggregatorWithParameterForIterateDelta() throws Exception {
    /*
     * Test aggregator with parameter for iterateDelta
     */
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(parallelism);

    DataSet<Tuple2<Integer, Integer>> initialSolutionSet =
        CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap());

    DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration =
        initialSolutionSet.iterateDelta(initialSolutionSet, MAX_ITERATIONS, 0);

    // register aggregator
    LongSumAggregator aggr = new LongSumAggregatorWithParameter(4);
    iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);

    DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateMapDelta());

    DataSet<Tuple2<Integer, Integer>> newElements = updatedDs
        .join(iteration.getSolutionSet()).where(0).equalTo(0)
        .flatMap(new UpdateFilter());

    DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements);
    DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper());
    result.writeAsText(resultPath);

    env.execute();

    expected = "1\n" + "2\n" + "2\n" + "3\n" + "3\n" + "3\n" + "4\n" + "4\n" + "4\n" + "4\n"
        + "5\n" + "5\n" + "5\n" + "5\n" + "5\n";
}
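LongSumAggregatorWithParameter is likewise defined elsewhere in AggregatorsITCase. A plausible sketch is a LongSumAggregator that merely carries an extra constructor parameter, which user functions can read back after fetching the aggregator by name (the getValue accessor is an assumption):

public static class LongSumAggregatorWithParameter extends LongSumAggregator {

    private final int value;

    public LongSumAggregatorWithParameter(int value) {
        this.value = value;
    }

    // hypothetical accessor: lets rich functions inside the iteration retrieve the parameter
    public int getValue() {
        return value;
    }
}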
Use of org.apache.flink.api.common.aggregators.LongSumAggregator in project flink by apache.
The class JobGraphGenerator, method finalizeWorksetIteration.
private void finalizeWorksetIteration(IterationDescriptor descr) {
    final WorksetIterationPlanNode iterNode = (WorksetIterationPlanNode) descr.getIterationNode();
    final JobVertex headVertex = descr.getHeadTask();
    final TaskConfig headConfig = new TaskConfig(headVertex.getConfiguration());
    final TaskConfig headFinalOutputConfig = descr.getHeadFinalResultConfig();

    // ------------ finalize the head config with the final outputs and the sync gate ------------
    {
        final int numStepFunctionOuts = headConfig.getNumOutputs();
        final int numFinalOuts = headFinalOutputConfig.getNumOutputs();

        if (numStepFunctionOuts == 0) {
            throw new CompilerException("The workset iteration has no operation on the workset inside the step function.");
        }

        headConfig.setIterationHeadFinalOutputConfig(headFinalOutputConfig);
        headConfig.setIterationHeadIndexOfSyncOutput(numStepFunctionOuts + numFinalOuts);

        final double relativeMemory = iterNode.getRelativeMemoryPerSubTask();
        if (relativeMemory <= 0) {
            throw new CompilerException("Bug: No memory has been assigned to the workset iteration.");
        }

        headConfig.setIsWorksetIteration();
        headConfig.setRelativeBackChannelMemory(relativeMemory / 2);
        headConfig.setRelativeSolutionSetMemory(relativeMemory / 2);

        // set the solution set serializer and comparator
        headConfig.setSolutionSetSerializer(iterNode.getSolutionSetSerializer());
        headConfig.setSolutionSetComparator(iterNode.getSolutionSetComparator());
    }

    // --------------------------- create the sync task ---------------------------
    final TaskConfig syncConfig;
    {
        final JobVertex sync = new JobVertex("Sync (" + iterNode.getNodeName() + ")");
        sync.setResources(iterNode.getMinResources(), iterNode.getPreferredResources());
        sync.setInvokableClass(IterationSynchronizationSinkTask.class);
        sync.setParallelism(1);
        sync.setMaxParallelism(1);
        this.auxVertices.add(sync);

        syncConfig = new TaskConfig(sync.getConfiguration());
        syncConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(0, headVertex.getParallelism());

        // set the number of iterations / the convergence criterion for the sync
        final int maxNumIterations = iterNode.getIterationNode().getIterationContract().getMaximumNumberOfIterations();
        if (maxNumIterations < 1) {
            throw new CompilerException("Cannot create workset iteration with unspecified maximum number of iterations.");
        }
        syncConfig.setNumberOfIterations(maxNumIterations);

        // connect the sync task
        sync.connectNewDataSetAsInput(headVertex, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
    }

    // ----------------------------- create the iteration tails -----------------------------
    // ----------------------- for next workset and solution set delta -----------------------
    {
        // we have three possible cases:
        // 1) Two tails, one for the workset update, one for the solution set update
        // 2) One tail for the workset update; the solution set update happens in an intermediate task
        // 3) One tail for the solution set update; the workset update happens in an intermediate task
        final PlanNode nextWorksetNode = iterNode.getNextWorkSetPlanNode();
        final PlanNode solutionDeltaNode = iterNode.getSolutionSetDeltaPlanNode();

        final boolean hasWorksetTail = nextWorksetNode.getOutgoingChannels().isEmpty();
        final boolean hasSolutionSetTail = (!iterNode.isImmediateSolutionSetUpdate()) || (!hasWorksetTail);

        {
            // get the vertex for the workset update
            final TaskConfig worksetTailConfig;
            JobVertex nextWorksetVertex = this.vertices.get(nextWorksetNode);
            if (nextWorksetVertex == null) {
                // the next workset node is chained
                TaskInChain taskInChain = this.chainedTasks.get(nextWorksetNode);
                if (taskInChain == null) {
                    throw new CompilerException("Bug: Next workset node not found as vertex or chained task.");
                }
                nextWorksetVertex = taskInChain.getContainingVertex();
                worksetTailConfig = taskInChain.getTaskConfig();
            } else {
                worksetTailConfig = new TaskConfig(nextWorksetVertex.getConfiguration());
            }

            // mark the node to perform workset updates
            worksetTailConfig.setIsWorksetIteration();
            worksetTailConfig.setIsWorksetUpdate();

            if (hasWorksetTail) {
                nextWorksetVertex.setInvokableClass(IterationTailTask.class);
                worksetTailConfig.setOutputSerializer(iterNode.getWorksetSerializer());
            }
        }
        {
            final TaskConfig solutionDeltaConfig;
            JobVertex solutionDeltaVertex = this.vertices.get(solutionDeltaNode);
            if (solutionDeltaVertex == null) {
                // the last operator is chained
                TaskInChain taskInChain = this.chainedTasks.get(solutionDeltaNode);
                if (taskInChain == null) {
                    throw new CompilerException("Bug: Solution Set Delta not found as vertex or chained task.");
                }
                solutionDeltaVertex = taskInChain.getContainingVertex();
                solutionDeltaConfig = taskInChain.getTaskConfig();
            } else {
                solutionDeltaConfig = new TaskConfig(solutionDeltaVertex.getConfiguration());
            }

            solutionDeltaConfig.setIsWorksetIteration();
            solutionDeltaConfig.setIsSolutionSetUpdate();

            if (hasSolutionSetTail) {
                solutionDeltaVertex.setInvokableClass(IterationTailTask.class);
                solutionDeltaConfig.setOutputSerializer(iterNode.getSolutionSetSerializer());

                // tell the head that it needs to wait for the solution set updates
                headConfig.setWaitForSolutionSetUpdate();
            } else {
                // no tail, intermediate update: it must be an immediate update
                if (!iterNode.isImmediateSolutionSetUpdate()) {
                    throw new CompilerException("A solution set update without dedicated tail is not set to perform immediate updates.");
                }
                solutionDeltaConfig.setIsSolutionSetUpdateWithoutReprobe();
            }
        }
    }

    // ------------------- register the aggregators -------------------
    AggregatorRegistry aggs = iterNode.getIterationNode().getIterationContract().getAggregators();
    Collection<AggregatorWithName<?>> allAggregators = aggs.getAllRegisteredAggregators();

    for (AggregatorWithName<?> agg : allAggregators) {
        if (agg.getName().equals(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME)) {
            throw new CompilerException("User defined aggregator used the same name as built-in workset "
                + "termination check aggregator: " + WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME);
        }
    }

    headConfig.addIterationAggregators(allAggregators);
    syncConfig.addIterationAggregators(allAggregators);

    String convAggName = aggs.getConvergenceCriterionAggregatorName();
    ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion();
    if (convCriterion != null || convAggName != null) {
        if (convCriterion == null) {
            throw new CompilerException("Error: Convergence criterion aggregator set, but criterion is null.");
        }
        if (convAggName == null) {
            throw new CompilerException("Error: Aggregator convergence criterion set, but aggregator is null.");
        }
        syncConfig.setConvergenceCriterion(convAggName, convCriterion);
    }

    headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
    syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
    syncConfig.setImplicitConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
        new WorksetEmptyConvergenceCriterion());
}
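The aggregators registered on the head and sync TaskConfigs above become visible to user code at runtime through the IterationRuntimeContext. A minimal usage sketch (the function class and the aggregator name "my.custom.sum" are illustrative, not part of the generator code):

public static class CountingMapper extends RichMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {

    private LongSumAggregator aggregator;

    @Override
    public void open(Configuration parameters) {
        // fetch the aggregator registered under the same name on the iteration
        aggregator = getIterationRuntimeContext().getIterationAggregator("my.custom.sum");
    }

    @Override
    public Tuple2<Long, Long> map(Tuple2<Long, Long> value) {
        aggregator.aggregate(1L); // contributes to the per-superstep sum collected by the sync task
        return value;
    }
}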
Use of org.apache.flink.api.common.aggregators.LongSumAggregator in project flink by apache.
The class SpargelTranslationTest, method testTranslationPlainEdges.
@Test
public void testTranslationPlainEdges() {
    try {
        final String ITERATION_NAME = "Test Name";
        final String AGGREGATOR_NAME = "AggregatorName";
        final String BC_SET_MESSAGES_NAME = "borat messages";
        final String BC_SET_UPDATES_NAME = "borat updates";
        final int NUM_ITERATIONS = 13;
        final int ITERATION_parallelism = 77;

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Long> bcMessaging = env.fromElements(1L);
        DataSet<Long> bcUpdate = env.fromElements(1L);
        DataSet<Vertex<String, Double>> result;

        // ------------ construct the test program ------------------
        {
            DataSet<Tuple2<String, Double>> initialVertices = env.fromElements(new Tuple2<>("abc", 3.44));
            DataSet<Tuple2<String, String>> edges = env.fromElements(new Tuple2<>("a", "c"));

            Graph<String, Double, NullValue> graph = Graph.fromTupleDataSet(initialVertices,
                edges.map(new MapFunction<Tuple2<String, String>, Tuple3<String, String, NullValue>>() {
                    public Tuple3<String, String, NullValue> map(Tuple2<String, String> edge) {
                        return new Tuple3<>(edge.f0, edge.f1, NullValue.getInstance());
                    }
                }), env);

            ScatterGatherConfiguration parameters = new ScatterGatherConfiguration();
            parameters.addBroadcastSetForScatterFunction(BC_SET_MESSAGES_NAME, bcMessaging);
            parameters.addBroadcastSetForGatherFunction(BC_SET_UPDATES_NAME, bcUpdate);
            parameters.setName(ITERATION_NAME);
            parameters.setParallelism(ITERATION_parallelism);
            parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());

            result = graph.runScatterGatherIteration(new MessageFunctionNoEdgeValue(), new UpdateFunction(),
                NUM_ITERATIONS, parameters).getVertices();
            result.output(new DiscardingOutputFormat<Vertex<String, Double>>());
        }

        // ------------- validate the Java program ----------------
        assertTrue(result instanceof DeltaIterationResultSet);
        DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
        DeltaIteration<?, ?> iteration = resultSet.getIterationHead();

        // check the basic iteration properties
        assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
        assertArrayEquals(new int[] { 0 }, resultSet.getKeyPositions());
        assertEquals(ITERATION_parallelism, iteration.getParallelism());
        assertEquals(ITERATION_NAME, iteration.getName());
        assertEquals(AGGREGATOR_NAME,
            iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());

        // validate that the semantic properties are set as they should be
        TwoInputUdfOperator<?, ?, ?, ?> solutionSetJoin = (TwoInputUdfOperator<?, ?, ?, ?>) resultSet.getNextWorkset();
        assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(0, 0).contains(0));
        assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(1, 0).contains(0));

        TwoInputUdfOperator<?, ?, ?, ?> edgesJoin = (TwoInputUdfOperator<?, ?, ?, ?>) solutionSetJoin.getInput1();

        // validate that the broadcast sets are forwarded
        assertEquals(bcUpdate, solutionSetJoin.getBroadcastSets().get(BC_SET_UPDATES_NAME));
        assertEquals(bcMessaging, edgesJoin.getBroadcastSets().get(BC_SET_MESSAGES_NAME));
    } catch (Exception e) {
        System.err.println(e.getMessage());
        e.printStackTrace();
        fail(e.getMessage());
    }
}
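The UpdateFunction used above is defined elsewhere in the test. A hedged sketch of how a gather-side function could consume the registered aggregator and broadcast set (the message type and the method bodies are assumptions):

public static class UpdateFunction extends GatherFunction<String, Double, Long> {

    @Override
    public void preSuperstep() {
        // the aggregator registered via ScatterGatherConfiguration is looked up by name
        LongSumAggregator aggregator = getIterationAggregator("AggregatorName");
        aggregator.aggregate(1L);
        // the broadcast set added with addBroadcastSetForGatherFunction is also available
        Collection<Long> updates = getBroadcastSet("borat updates");
    }

    @Override
    public void updateVertex(Vertex<String, Double> vertex, MessageIterator<Long> messages) {
        setNewVertexValue(vertex.getValue() + 1.0);
    }
}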
Use of org.apache.flink.api.common.aggregators.LongSumAggregator in project flink by apache.
The class GatherSumApplyConfigurationITCase, method testRunWithConfiguration.
@Test
public void testRunWithConfiguration() throws Exception {
    /*
     * Test Graph's runGatherSumApplyIteration when configuration parameters are provided
     */
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    Graph<Long, Long, Long> graph = Graph.fromCollection(TestGraphUtils.getLongLongVertices(),
        TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper());

    // create the configuration object
    GSAConfiguration parameters = new GSAConfiguration();
    parameters.addBroadcastSetForGatherFunction("gatherBcastSet", env.fromElements(1, 2, 3));
    parameters.addBroadcastSetForSumFunction("sumBcastSet", env.fromElements(4, 5, 6));
    parameters.addBroadcastSetForApplyFunction("applyBcastSet", env.fromElements(7, 8, 9));
    parameters.registerAggregator("superstepAggregator", new LongSumAggregator());
    parameters.setOptNumVertices(true);

    Graph<Long, Long, Long> res = graph.runGatherSumApplyIteration(new Gather(), new Sum(), new Apply(), 10, parameters);

    DataSet<Vertex<Long, Long>> data = res.getVertices();
    List<Vertex<Long, Long>> result = data.collect();

    expectedResult = "1,11\n" + "2,11\n" + "3,11\n" + "4,11\n" + "5,11";
    compareResultAsTuples(result, expectedResult);
}
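The Gather, Sum, and Apply classes are defined elsewhere in the test. A sketch of how the apply phase might read the configuration set above; the bodies are assumptions, chosen to be consistent with the expected "vertex,11" values after 10 supersteps:

public static class Apply extends ApplyFunction<Long, Long, Long> {

    private LongSumAggregator aggregator;

    @Override
    public void preSuperstep() {
        // aggregator and broadcast set registered on the GSAConfiguration above
        aggregator = getIterationAggregator("superstepAggregator");
        Collection<Integer> applyBcastSet = getBroadcastSet("applyBcastSet");
    }

    @Override
    public void apply(Long summedValue, Long currentValue) {
        aggregator.aggregate(1L);
        setResult(currentValue + 1); // +1 per superstep: the initial value 1 grows to 11 after 10 iterations
    }
}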