use of org.apache.flink.api.common.operators.base.DeltaIterationBase in project flink by apache.
the class DeltaIterationTranslationTest method testCorrectTranslation.
/**
 * Builds a small delta-iteration program with the DataSet API, translates it into a
 * {@code Plan}, and checks that the resulting {@code DeltaIterationBase} carries the
 * configured name, parallelism, key fields, iteration bound, aggregator and step-function
 * operators. Both data sinks must share the very same iteration operator as their input.
 */
@Test
public void testCorrectTranslation() {
    final String expectedJobName = "Test JobName";
    final String expectedIterationName = "Test Name";
    final String expectedNextWorksetMapName = "Some Mapper";
    final String expectedAggregatorName = "AggregatorName";
    final int[] solutionSetKeys = {2};
    final int maxIterations = 13;
    final int envParallelism = 133;
    final int iterParallelism = 77;

    try {
        // ------------ construct the test program ------------------
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(envParallelism);

        @SuppressWarnings("unchecked")
        DataSet<Tuple3<Double, Long, String>> solutionSetInput =
                env.fromElements(new Tuple3<Double, Long, String>(3.44, 5L, "abc"));
        @SuppressWarnings("unchecked")
        DataSet<Tuple2<Double, String>> worksetInput =
                env.fromElements(new Tuple2<Double, String>(1.23, "abc"));

        DeltaIteration<Tuple3<Double, Long, String>, Tuple2<Double, String>> deltaIteration =
                solutionSetInput.iterateDelta(worksetInput, maxIterations, solutionSetKeys);
        deltaIteration.name(expectedIterationName).parallelism(iterParallelism);
        deltaIteration.registerAggregator(expectedAggregatorName, new LongSumAggregator());

        // the workset is consumed twice to test that multiple workset consumers are supported
        DataSet<Tuple2<Double, String>> selfJoinedWorkset = deltaIteration.getWorkset()
                .map(new IdentityMapper<Tuple2<Double, String>>())
                .join(deltaIteration.getWorkset())
                .where(1)
                .equalTo(1)
                .projectFirst(0, 1);

        DataSet<Tuple3<Double, Long, String>> delta = selfJoinedWorkset
                .join(deltaIteration.getSolutionSet())
                .where(1)
                .equalTo(2)
                .with(new SolutionWorksetJoin());

        DataSet<Tuple3<Double, Long, String>> iterationResult = deltaIteration.closeWith(
                delta,
                delta.map(new NextWorksetMapper()).name(expectedNextWorksetMapName));

        // two sinks consume the iteration result
        iterationResult.output(new DiscardingOutputFormat<Tuple3<Double, Long, String>>());
        iterationResult.writeAsText("/dev/null");

        Plan plan = env.createProgramPlan(expectedJobName);

        // ------------- validate the plan ----------------
        assertEquals(expectedJobName, plan.getJobName());
        assertEquals(envParallelism, plan.getDefaultParallelism());

        // pick up the two sinks in iteration order
        Iterator<? extends GenericDataSinkBase<?>> sinkIterator = plan.getDataSinks().iterator();
        GenericDataSinkBase<?> firstSink = sinkIterator.next();
        GenericDataSinkBase<?> secondSink = sinkIterator.next();

        DeltaIterationBase<?, ?> translated = (DeltaIterationBase<?, ?>) firstSink.getInput();

        // check that multi consumer translation works for iterations
        assertEquals(translated, secondSink.getInput());

        // check the basic iteration properties
        assertEquals(maxIterations, translated.getMaximumNumberOfIterations());
        assertArrayEquals(solutionSetKeys, translated.getSolutionSetKeyFields());
        assertEquals(iterParallelism, translated.getParallelism());
        assertEquals(expectedIterationName, translated.getName());

        // walk the step function from its two roots down to the identity mapper
        MapOperatorBase<?, ?, ?> nextWorksetMap = (MapOperatorBase<?, ?, ?>) translated.getNextWorkset();
        InnerJoinOperatorBase<?, ?, ?, ?> deltaJoin =
                (InnerJoinOperatorBase<?, ?, ?, ?>) translated.getSolutionSetDelta();
        InnerJoinOperatorBase<?, ?, ?, ?> selfJoin =
                (InnerJoinOperatorBase<?, ?, ?, ?>) deltaJoin.getFirstInput();
        MapOperatorBase<?, ?, ?> identityMap = (MapOperatorBase<?, ?, ?>) selfJoin.getFirstInput();

        assertEquals(IdentityMapper.class, identityMap.getUserCodeWrapper().getUserCodeClass());
        assertEquals(NextWorksetMapper.class, nextWorksetMap.getUserCodeWrapper().getUserCodeClass());

        // the join function may be wrapped; unwrap it before comparing the class
        Object joinUdf = deltaJoin.getUserCodeWrapper().getUserCodeObject();
        if (joinUdf instanceof WrappingFunction) {
            WrappingFunction<?> wrapper = (WrappingFunction<?>) joinUdf;
            assertEquals(SolutionWorksetJoin.class, wrapper.getWrappedFunction().getClass());
        } else {
            assertEquals(SolutionWorksetJoin.class, deltaJoin.getUserCodeWrapper().getUserCodeClass());
        }

        assertEquals(expectedNextWorksetMapName, nextWorksetMap.getName());
        assertEquals(
                expectedAggregatorName,
                translated.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
    } catch (Exception e) {
        System.err.println(e.getMessage());
        e.printStackTrace();
        fail(e.getMessage());
    }
}
use of org.apache.flink.api.common.operators.base.DeltaIterationBase in project flink by apache.
the class JavaApiPostPass method traverse.
/**
 * Visits the given plan node once, assigns the serializers and comparators that its node
 * type requires, and continues the traversal along its input channels (and, for iteration
 * nodes, into the step function).
 *
 * <p>The node types are tested in a fixed order; the iteration plan nodes are checked
 * before the generic single-/dual-input plan nodes.
 *
 * @param node the plan node to process
 * @throws CompilerException if an iteration's step function result is produced by a union node
 * @throws CompilerPostPassException if the node is of an unknown type
 */
protected void traverse(PlanNode node) {
    // every node is handled at most once
    if (!alreadyDone.add(node)) {
        return;
    }

    if (node instanceof SinkPlanNode) {
        // sinks need nothing themselves; descend to the input channel
        traverseChannel(((SinkPlanNode) node).getInput());
        return;
    }

    if (node instanceof SourcePlanNode) {
        SourcePlanNode source = (SourcePlanNode) node;
        TypeInformation<?> sourceType = getTypeInfoFromSource(source);
        source.setSerializer(createSerializer(sourceType));
        return;
    }

    if (node instanceof BulkIterationPlanNode) {
        BulkIterationPlanNode bulk = (BulkIterationPlanNode) node;
        if (bulk.getRootOfStepFunction() instanceof NAryUnionPlanNode) {
            throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node.");
        }

        // if a termination criterion exists, traverse the input of its root node first
        if (bulk.getRootOfTerminationCriterion() != null) {
            SingleInputPlanNode criterionRoot = (SingleInputPlanNode) bulk.getRootOfTerminationCriterion();
            traverseChannel(criterionRoot.getInput());
        }

        BulkIterationBase<?> bulkOperator = (BulkIterationBase<?>) bulk.getProgramOperator();
        // the iteration channel is serialized with the operator's output type
        bulk.setSerializerForIterationChannel(createSerializer(bulkOperator.getOperatorInfo().getOutputType()));

        // continue with the iteration input, then the step function
        traverseChannel(bulk.getInput());
        traverse(bulk.getRootOfStepFunction());
        return;
    }

    if (node instanceof WorksetIterationPlanNode) {
        WorksetIterationPlanNode delta = (WorksetIterationPlanNode) node;
        if (delta.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) {
            throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node.");
        }
        if (delta.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) {
            throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node.");
        }

        DeltaIterationBase<?, ?> deltaOperator = (DeltaIterationBase<?, ?>) delta.getProgramOperator();
        TypeInformation<?> solutionSetType = deltaOperator.getOperatorInfo().getFirstInputType();
        TypeInformation<?> worksetType = deltaOperator.getOperatorInfo().getSecondInputType();

        // serializers for both iteration inputs, comparator for the solution set keys
        delta.setSolutionSetSerializer(createSerializer(solutionSetType));
        delta.setWorksetSerializer(createSerializer(worksetType));
        delta.setSolutionSetComparator(createComparator(
                solutionSetType,
                delta.getSolutionSetKeyFields(),
                getSortOrders(delta.getSolutionSetKeyFields(), null)));

        // inputs first ...
        traverseChannel(delta.getInput1());
        traverseChannel(delta.getInput2());

        // ... then both roots of the step function
        traverse(delta.getSolutionSetDeltaPlanNode());
        traverse(delta.getNextWorkSetPlanNode());
        return;
    }

    if (node instanceof SingleInputPlanNode) {
        SingleInputPlanNode single = (SingleInputPlanNode) node;

        if (!(single.getOptimizerNode().getOperator() instanceof SingleInputOperator)) {
            if (!(single.getOptimizerNode().getOperator() instanceof NoOpUnaryUdfOp)) {
                throw new RuntimeException("Wrong operator type found in post pass.");
            }
            // special case for delta iterations: no comparators to set, just pass through
            traverseChannel(single.getInput());
            return;
        }

        SingleInputOperator<?, ?, ?> unaryOperator = (SingleInputOperator<?, ?, ?>) single.getOptimizerNode().getOperator();

        // one comparator per slot that the driver strategy asks for
        final int comparatorCount = single.getDriverStrategy().getNumRequiredComparators();
        for (int slot = 0; slot < comparatorCount; slot++) {
            single.setComparator(
                    createComparator(
                            unaryOperator.getOperatorInfo().getInputType(),
                            single.getKeys(slot),
                            getSortOrders(single.getKeys(slot), single.getSortOrders(slot))),
                    slot);
        }

        // regular input, followed by the broadcast inputs
        traverseChannel(single.getInput());
        for (Channel broadcast : single.getBroadcastInputs()) {
            traverseChannel(broadcast);
        }
        return;
    }

    if (node instanceof DualInputPlanNode) {
        DualInputPlanNode dual = (DualInputPlanNode) node;
        if (!(dual.getOptimizerNode().getOperator() instanceof DualInputOperator)) {
            throw new RuntimeException("Wrong operator type found in post pass.");
        }

        DualInputOperator<?, ?, ?, ?> binaryOperator = (DualInputOperator<?, ?, ?, ?>) dual.getOptimizerNode().getOperator();

        // comparators are only created when the driver strategy requires any
        if (dual.getDriverStrategy().getNumRequiredComparators() > 0) {
            TypeInformation<?> firstType = binaryOperator.getOperatorInfo().getFirstInputType();
            TypeInformation<?> secondType = binaryOperator.getOperatorInfo().getSecondInputType();

            dual.setComparator1(createComparator(
                    firstType,
                    dual.getKeysForInput1(),
                    getSortOrders(dual.getKeysForInput1(), dual.getSortOrders())));
            dual.setComparator2(createComparator(
                    secondType,
                    dual.getKeysForInput2(),
                    getSortOrders(dual.getKeysForInput2(), dual.getSortOrders())));
            dual.setPairComparator(createPairComparator(firstType, secondType));
        }

        // both regular inputs, followed by the broadcast inputs
        traverseChannel(dual.getInput1());
        traverseChannel(dual.getInput2());
        for (Channel broadcast : dual.getBroadcastInputs()) {
            traverseChannel(broadcast);
        }
        return;
    }

    // the sources of the iterative step functions need no processing
    if (node instanceof BulkPartialSolutionPlanNode
            || node instanceof SolutionSetPlanNode
            || node instanceof WorksetPlanNode) {
        return;
    }

    if (node instanceof NAryUnionPlanNode) {
        // a union only forwards; follow every child channel
        for (Channel unionInput : node.getInputs()) {
            traverseChannel(unionInput);
        }
        return;
    }

    throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName());
}
use of org.apache.flink.api.common.operators.base.DeltaIterationBase in project flink by apache.
the class GraphCreatingVisitor method postVisit.
/**
 * Connects the optimizer node registered for the given operator to its predecessors and
 * broadcast inputs. If the node is a bulk or workset (delta) iteration node, the step
 * function of the iteration is translated recursively with a nested
 * {@code GraphCreatingVisitor}, the resulting nodes are attached to the iteration node, and
 * the nodes of the step function are visited by a {@code StaticDynamicPathIdentifier}.
 *
 * @param c the operator whose optimizer node is to be connected
 * @throws CompilerException if the step function does not depend on the partial solution
 *         (bulk iteration) or on the workset (delta iteration), or if a consumer of the
 *         solution set does not have the solution set as one of its two inputs
 * @throws InvalidProgramException if the solution set is consumed by anything other than a
 *         Join or a CoGroup
 */
@Override
public void postVisit(Operator<?> c) {
    OptimizerNode n = this.con2node.get(c);

    // first connect to the predecessors
    n.setInput(this.con2node, this.defaultDataExchangeMode);
    n.setBroadcastInputs(this.con2node, this.defaultDataExchangeMode);

    // if the node represents a bulk iteration, we recursively translate the data flow now
    if (n instanceof BulkIterationNode) {
        final BulkIterationNode iterNode = (BulkIterationNode) n;
        final BulkIterationBase<?> iter = iterNode.getIterationContract();

        // pass a copy of the non-iterative part into the iteration translation,
        // in case the iteration references its closure
        HashMap<Operator<?>, OptimizerNode> closure = new HashMap<Operator<?>, OptimizerNode>(con2node);

        // first, recursively build the data flow for the step function
        final GraphCreatingVisitor recursiveCreator = new GraphCreatingVisitor(this, true, iterNode.getParallelism(), defaultDataExchangeMode, closure);

        BulkPartialSolutionNode partialSolution;
        iter.getNextPartialSolution().accept(recursiveCreator);

        // the partial solution node only exists if the descent above reached its placeholder
        partialSolution = (BulkPartialSolutionNode) recursiveCreator.con2node.get(iter.getPartialSolution());
        OptimizerNode rootOfStepFunction = recursiveCreator.con2node.get(iter.getNextPartialSolution());
        if (partialSolution == null) {
            throw new CompilerException("Error: The step functions result does not depend on the partial solution.");
        }

        OptimizerNode terminationCriterion = null;
        if (iter.getTerminationCriterion() != null) {
            terminationCriterion = recursiveCreator.con2node.get(iter.getTerminationCriterion());

            // the criterion was not reached while translating the step function:
            // translate it separately so that its node gets created
            if (terminationCriterion == null) {
                iter.getTerminationCriterion().accept(recursiveCreator);
                terminationCriterion = recursiveCreator.con2node.get(iter.getTerminationCriterion());
            }
        }

        iterNode.setPartialSolution(partialSolution);
        iterNode.setNextPartialSolution(rootOfStepFunction, terminationCriterion);

        // go over the contained data flow and mark the dynamic path nodes
        StaticDynamicPathIdentifier identifier = new StaticDynamicPathIdentifier(iterNode.getCostWeight());
        iterNode.acceptForStepFunction(identifier);
    } else if (n instanceof WorksetIterationNode) {
        final WorksetIterationNode iterNode = (WorksetIterationNode) n;
        final DeltaIterationBase<?, ?> iter = iterNode.getIterationContract();

        // we need to ensure that both the next-workset and the solution-set-delta depend on the
        // workset. One check is for free during the translation, we do the other check here as a
        // pre-condition
        {
            StepFunctionValidator wsf = new StepFunctionValidator();
            iter.getNextWorkset().accept(wsf);
            if (!wsf.hasFoundWorkset()) {
                throw new CompilerException("In the given program, the next workset does not depend on the workset. " + "This is a prerequisite in delta iterations.");
            }
        }

        // calculate the closure of the anonymous function
        HashMap<Operator<?>, OptimizerNode> closure = new HashMap<Operator<?>, OptimizerNode>(con2node);

        // first, recursively build the data flow for the step function
        final GraphCreatingVisitor recursiveCreator = new GraphCreatingVisitor(this, true, iterNode.getParallelism(), defaultDataExchangeMode, closure);

        // descend from the solution set delta. check that it depends on both the workset
        // and the solution set. If it does depend on both, this descend should create both
        // nodes
        iter.getSolutionSetDelta().accept(recursiveCreator);

        final WorksetNode worksetNode = (WorksetNode) recursiveCreator.con2node.get(iter.getWorkset());
        if (worksetNode == null) {
            // fixed: the two sentences of this message were concatenated without a separating space
            throw new CompilerException("In the given program, the solution set delta does not depend on the workset. " + "This is a prerequisite in delta iterations.");
        }

        iter.getNextWorkset().accept(recursiveCreator);

        SolutionSetNode solutionSetNode = (SolutionSetNode) recursiveCreator.con2node.get(iter.getSolutionSet());
        if (solutionSetNode == null || solutionSetNode.getOutgoingConnections() == null || solutionSetNode.getOutgoingConnections().isEmpty()) {
            // the solution set is not consumed inside the step function: create a fresh node for it
            solutionSetNode = new SolutionSetNode((DeltaIterationBase.SolutionSetPlaceHolder<?>) iter.getSolutionSet(), iterNode);
        } else {
            // every consumer of the solution set must be exactly a JoinNode or a CoGroupNode
            // (exact class comparison, subclasses are rejected)
            for (DagConnection conn : solutionSetNode.getOutgoingConnections()) {
                OptimizerNode successor = conn.getTarget();

                if (successor.getClass() == JoinNode.class) {
                    // find out which input of the join the solution set is
                    JoinNode mn = (JoinNode) successor;
                    if (mn.getFirstPredecessorNode() == solutionSetNode) {
                        mn.makeJoinWithSolutionSet(0);
                    } else if (mn.getSecondPredecessorNode() == solutionSetNode) {
                        mn.makeJoinWithSolutionSet(1);
                    } else {
                        throw new CompilerException("The solution set is connected to a Join node, but is neither the first nor the second input of that node.");
                    }
                } else if (successor.getClass() == CoGroupNode.class) {
                    // find out which input of the co-group the solution set is
                    CoGroupNode cg = (CoGroupNode) successor;
                    if (cg.getFirstPredecessorNode() == solutionSetNode) {
                        cg.makeCoGroupWithSolutionSet(0);
                    } else if (cg.getSecondPredecessorNode() == solutionSetNode) {
                        cg.makeCoGroupWithSolutionSet(1);
                    } else {
                        throw new CompilerException("The solution set is connected to a CoGroup node, but is neither the first nor the second input of that node.");
                    }
                } else {
                    throw new InvalidProgramException("Error: The only operations allowed on the solution set are Join and CoGroup.");
                }
            }
        }

        final OptimizerNode nextWorksetNode = recursiveCreator.con2node.get(iter.getNextWorkset());
        final OptimizerNode solutionSetDeltaNode = recursiveCreator.con2node.get(iter.getSolutionSetDelta());

        // set the step function nodes to the iteration node
        iterNode.setPartialSolution(solutionSetNode, worksetNode);
        iterNode.setNextPartialSolution(solutionSetDeltaNode, nextWorksetNode, defaultDataExchangeMode);

        // go over the contained data flow and mark the dynamic path nodes
        StaticDynamicPathIdentifier pathIdentifier = new StaticDynamicPathIdentifier(iterNode.getCostWeight());
        iterNode.acceptForStepFunction(pathIdentifier);
    }
}
use of org.apache.flink.api.common.operators.base.DeltaIterationBase in project flink by apache.
the class GraphCreatingVisitor method preVisit.
/**
 * Creates and registers the optimizer node for the given operator, unless the operator has
 * been visited before. After registration the node receives a parallelism if it does not
 * report one of at least 1 yet.
 *
 * <p>The operator types are tested in a fixed order; the order of the checks is part of the
 * behavior and must not be changed.
 *
 * @param c the operator to create an optimizer node for
 * @return {@code false} if the operator already has a node, {@code true} after a node was
 *         created for it
 * @throws InvalidProgramException if an iteration placeholder is encountered by a visitor
 *         that has no parent visitor
 * @throws IllegalArgumentException if the operator type is not known
 */
@SuppressWarnings("deprecation")
@Override
public boolean preVisit(Operator<?> c) {
    // nothing to do for operators that were translated earlier
    if (con2node.containsKey(c)) {
        return false;
    }

    // pick the optimizer node type that matches the operator (or sink or source)
    final OptimizerNode node;
    if (c instanceof GenericDataSinkBase) {
        DataSinkNode sinkNode = new DataSinkNode((GenericDataSinkBase<?>) c);
        sinks.add(sinkNode);
        node = sinkNode;
    } else if (c instanceof GenericDataSourceBase) {
        node = new DataSourceNode((GenericDataSourceBase<?, ?>) c);
    } else if (c instanceof MapOperatorBase) {
        node = new MapNode((MapOperatorBase<?, ?, ?>) c);
    } else if (c instanceof MapPartitionOperatorBase) {
        node = new MapPartitionNode((MapPartitionOperatorBase<?, ?, ?>) c);
    } else if (c instanceof FlatMapOperatorBase) {
        node = new FlatMapNode((FlatMapOperatorBase<?, ?, ?>) c);
    } else if (c instanceof FilterOperatorBase) {
        node = new FilterNode((FilterOperatorBase<?, ?>) c);
    } else if (c instanceof ReduceOperatorBase) {
        node = new ReduceNode((ReduceOperatorBase<?, ?>) c);
    } else if (c instanceof GroupCombineOperatorBase) {
        node = new GroupCombineNode((GroupCombineOperatorBase<?, ?, ?>) c);
    } else if (c instanceof GroupReduceOperatorBase) {
        node = new GroupReduceNode((GroupReduceOperatorBase<?, ?, ?>) c);
    } else if (c instanceof InnerJoinOperatorBase) {
        node = new JoinNode((InnerJoinOperatorBase<?, ?, ?, ?>) c);
    } else if (c instanceof OuterJoinOperatorBase) {
        node = new OuterJoinNode((OuterJoinOperatorBase<?, ?, ?, ?>) c);
    } else if (c instanceof CoGroupOperatorBase) {
        node = new CoGroupNode((CoGroupOperatorBase<?, ?, ?, ?>) c);
    } else if (c instanceof CoGroupRawOperatorBase) {
        node = new CoGroupRawNode((CoGroupRawOperatorBase<?, ?, ?, ?>) c);
    } else if (c instanceof CrossOperatorBase) {
        node = new CrossNode((CrossOperatorBase<?, ?, ?, ?>) c);
    } else if (c instanceof BulkIterationBase) {
        node = new BulkIterationNode((BulkIterationBase<?>) c);
    } else if (c instanceof DeltaIterationBase) {
        node = new WorksetIterationNode((DeltaIterationBase<?, ?>) c);
    } else if (c instanceof Union) {
        node = new BinaryUnionNode((Union<?>) c);
    } else if (c instanceof PartitionOperatorBase) {
        node = new PartitionNode((PartitionOperatorBase<?>) c);
    } else if (c instanceof SortPartitionOperatorBase) {
        node = new SortPartitionNode((SortPartitionOperatorBase<?>) c);
    } else if (c instanceof BulkIterationBase.PartialSolutionPlaceHolder) {
        // placeholders are resolved against the parent visitor, so one must exist
        if (parent == null) {
            throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
        }

        final BulkIterationBase.PartialSolutionPlaceHolder<?> holder = (BulkIterationBase.PartialSolutionPlaceHolder<?>) c;
        final BulkIterationNode owner = (BulkIterationNode) parent.con2node.get(holder.getContainingBulkIteration());

        // the placeholder node takes over the parallelism of its enclosing iteration node
        final BulkPartialSolutionNode placeholderNode = new BulkPartialSolutionNode(holder, owner);
        placeholderNode.setParallelism(owner.getParallelism());
        node = placeholderNode;
    } else if (c instanceof DeltaIterationBase.WorksetPlaceHolder) {
        if (parent == null) {
            throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
        }

        final DeltaIterationBase.WorksetPlaceHolder<?> holder = (DeltaIterationBase.WorksetPlaceHolder<?>) c;
        final WorksetIterationNode owner = (WorksetIterationNode) parent.con2node.get(holder.getContainingWorksetIteration());

        // the placeholder node takes over the parallelism of its enclosing iteration node
        final WorksetNode placeholderNode = new WorksetNode(holder, owner);
        placeholderNode.setParallelism(owner.getParallelism());
        node = placeholderNode;
    } else if (c instanceof DeltaIterationBase.SolutionSetPlaceHolder) {
        if (parent == null) {
            throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
        }

        final DeltaIterationBase.SolutionSetPlaceHolder<?> holder = (DeltaIterationBase.SolutionSetPlaceHolder<?>) c;
        final WorksetIterationNode owner = (WorksetIterationNode) parent.con2node.get(holder.getContainingWorksetIteration());

        // the placeholder node takes over the parallelism of its enclosing iteration node
        final SolutionSetNode placeholderNode = new SolutionSetNode(holder, owner);
        placeholderNode.setParallelism(owner.getParallelism());
        node = placeholderNode;
    } else {
        throw new IllegalArgumentException("Unknown operator type: " + c);
    }

    con2node.put(c, node);

    // nodes that already report a parallelism of at least 1 keep it
    if (node.getParallelism() >= 1) {
        return true;
    }

    final int requested = c.getParallelism();
    final int effective;
    if (node instanceof BinaryUnionNode) {
        // Keep parallelism of union undefined for now.
        // It will be determined based on the parallelism of its successor.
        effective = -1;
    } else if (requested <= 0) {
        // the operator did not request a parallelism: fall back to the default
        effective = defaultParallelism;
    } else if (forceParallelism && requested != defaultParallelism) {
        // a deviating request is overridden when the parallelism is forced
        Optimizer.LOG.warn("The parallelism of nested dataflows (such as step functions in iterations) is " + "currently fixed to the parallelism of the surrounding operator (the iteration).");
        effective = defaultParallelism;
    } else {
        effective = requested;
    }
    node.setParallelism(effective);
    return true;
}
Aggregations