Search in sources :

Example 1 with AggregatorRegistry

Use of org.apache.flink.api.common.aggregators.AggregatorRegistry in the Apache Flink project.

From the class JobGraphGenerator, method finalizeWorksetIteration:

/**
 * Finishes the job-graph wiring of a workset iteration.
 *
 * <p>In order, this method
 * <ol>
 *   <li>completes the head task configuration (final output config, index of the sync output,
 *       back-channel and solution-set memory shares, solution set serializer and comparator),
 *   <li>creates the synchronization sink vertex (parallelism 1) and connects it to the head,
 *   <li>flags the next-workset and solution-set-delta tasks as iteration updates and turns them
 *       into {@code IterationTailTask}s where they act as a tail,
 *   <li>registers the user-defined aggregators, the optional convergence criterion, and the
 *       built-in aggregator used for the empty-workset termination check.
 * </ol>
 *
 * @param descr descriptor supplying the iteration plan node, the head task and the head's
 *     final-result configuration
 * @throws CompilerException if the head has no step-function outputs, no memory was assigned,
 *     the maximum number of iterations is below one, a tail node is neither a vertex nor a
 *     chained task, a tail-less solution set update is not immediate, a user aggregator reuses
 *     the built-in aggregator name, or only one half of the convergence criterion is set
 */
private void finalizeWorksetIteration(IterationDescriptor descr) {
    final WorksetIterationPlanNode worksetIteration = (WorksetIterationPlanNode) descr.getIterationNode();
    final JobVertex head = descr.getHeadTask();
    final TaskConfig headConfig = new TaskConfig(head.getConfiguration());
    final TaskConfig finalOutputConfig = descr.getHeadFinalResultConfig();

    // ---------------------------------- head configuration ----------------------------------
    final int stepFunctionOutputs = headConfig.getNumOutputs();
    final int finalOutputs = finalOutputConfig.getNumOutputs();
    if (stepFunctionOutputs == 0) {
        throw new CompilerException("The workset iteration has no operation on the workset inside the step function.");
    }
    headConfig.setIterationHeadFinalOutputConfig(finalOutputConfig);
    // the sync output is indexed right after the step function outputs and the final outputs
    headConfig.setIterationHeadIndexOfSyncOutput(stepFunctionOutputs + finalOutputs);

    final double iterationMemory = worksetIteration.getRelativeMemoryPerSubTask();
    if (iterationMemory <= 0) {
        throw new CompilerException("Bug: No memory has been assigned to the workset iteration.");
    }
    headConfig.setIsWorksetIteration();
    // the assigned memory is split evenly between the back channel and the solution set
    headConfig.setRelativeBackChannelMemory(iterationMemory / 2);
    headConfig.setRelativeSolutionSetMemory(iterationMemory / 2);
    headConfig.setSolutionSetSerializer(worksetIteration.getSolutionSetSerializer());
    headConfig.setSolutionSetComparator(worksetIteration.getSolutionSetComparator());

    // -------------------------------------- sync task ---------------------------------------
    final JobVertex syncVertex = new JobVertex("Sync (" + worksetIteration.getNodeName() + ")");
    syncVertex.setResources(worksetIteration.getMinResources(), worksetIteration.getPreferredResources());
    syncVertex.setInvokableClass(IterationSynchronizationSinkTask.class);
    syncVertex.setParallelism(1);
    syncVertex.setMaxParallelism(1);
    this.auxVertices.add(syncVertex);

    final TaskConfig syncConfig = new TaskConfig(syncVertex.getConfiguration());
    syncConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(0, head.getParallelism());

    // the sync task needs an explicit upper bound on the number of iterations
    final int iterationLimit = worksetIteration.getIterationNode().getIterationContract().getMaximumNumberOfIterations();
    if (iterationLimit < 1) {
        throw new CompilerException("Cannot create workset iteration with unspecified maximum number of iterations.");
    }
    syncConfig.setNumberOfIterations(iterationLimit);
    syncVertex.connectNewDataSetAsInput(head, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);

    // ------------------------------------ iteration tails -----------------------------------
    // Three layouts are possible:
    //   1) two tails: one updates the workset, the other the solution set
    //   2) a single workset tail; the solution set is updated by an intermediate task
    //   3) a single solution set tail; the workset is updated by an intermediate task
    final PlanNode worksetUpdateNode = worksetIteration.getNextWorkSetPlanNode();
    final PlanNode deltaNode = worksetIteration.getSolutionSetDeltaPlanNode();
    final boolean worksetEndsInTail = worksetUpdateNode.getOutgoingChannels().isEmpty();
    final boolean solutionSetEndsInTail = !(worksetIteration.isImmediateSolutionSetUpdate() && worksetEndsInTail);

    // locate the task performing the workset update, either as its own vertex or inside a chain
    JobVertex worksetVertex = this.vertices.get(worksetUpdateNode);
    final TaskConfig worksetConfig;
    if (worksetVertex != null) {
        worksetConfig = new TaskConfig(worksetVertex.getConfiguration());
    } else {
        final TaskInChain chainedWorksetTask = this.chainedTasks.get(worksetUpdateNode);
        if (chainedWorksetTask == null) {
            throw new CompilerException("Bug: Next workset node not found as vertex or chained task.");
        }
        worksetVertex = chainedWorksetTask.getContainingVertex();
        worksetConfig = chainedWorksetTask.getTaskConfig();
    }
    worksetConfig.setIsWorksetIteration();
    worksetConfig.setIsWorksetUpdate();
    if (worksetEndsInTail) {
        worksetVertex.setInvokableClass(IterationTailTask.class);
        worksetConfig.setOutputSerializer(worksetIteration.getWorksetSerializer());
    }

    // locate the task producing the solution set delta, either as its own vertex or inside a chain
    JobVertex deltaVertex = this.vertices.get(deltaNode);
    final TaskConfig deltaConfig;
    if (deltaVertex != null) {
        deltaConfig = new TaskConfig(deltaVertex.getConfiguration());
    } else {
        final TaskInChain chainedDeltaTask = this.chainedTasks.get(deltaNode);
        if (chainedDeltaTask == null) {
            throw new CompilerException("Bug: Solution Set Delta not found as vertex or chained task.");
        }
        deltaVertex = chainedDeltaTask.getContainingVertex();
        deltaConfig = chainedDeltaTask.getTaskConfig();
    }
    deltaConfig.setIsWorksetIteration();
    deltaConfig.setIsSolutionSetUpdate();
    if (solutionSetEndsInTail) {
        deltaVertex.setInvokableClass(IterationTailTask.class);
        deltaConfig.setOutputSerializer(worksetIteration.getSolutionSetSerializer());
        // the head has to wait until the solution set updates have arrived
        headConfig.setWaitForSolutionSetUpdate();
    } else if (worksetIteration.isImmediateSolutionSetUpdate()) {
        // no dedicated tail: the update is applied immediately by an intermediate task
        deltaConfig.setIsSolutionSetUpdateWithoutReprobe();
    } else {
        throw new CompilerException("A solution set update without dedicated tail is not set to perform immediate updates.");
    }

    // -------------------------------------- aggregators -------------------------------------
    final AggregatorRegistry registry = worksetIteration.getIterationNode().getIterationContract().getAggregators();
    final Collection<AggregatorWithName<?>> userAggregators = registry.getAllRegisteredAggregators();

    // the built-in termination aggregator's name is reserved
    for (AggregatorWithName<?> userAggregator : userAggregators) {
        final boolean clashesWithBuiltIn = userAggregator.getName().equals(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME);
        if (clashesWithBuiltIn) {
            throw new CompilerException("User defined aggregator used the same name as built-in workset " + "termination check aggregator: " + WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME);
        }
    }
    headConfig.addIterationAggregators(userAggregators);
    syncConfig.addIterationAggregators(userAggregators);

    // a user-defined convergence criterion is only valid together with its aggregator name
    final String criterionAggregatorName = registry.getConvergenceCriterionAggregatorName();
    final ConvergenceCriterion<?> criterion = registry.getConvergenceCriterion();
    if (criterion != null && criterionAggregatorName != null) {
        syncConfig.setConvergenceCriterion(criterionAggregatorName, criterion);
    } else if (criterionAggregatorName != null) {
        throw new CompilerException("Error: Convergence criterion aggregator set, but criterion is null.");
    } else if (criterion != null) {
        throw new CompilerException("Error: Aggregator convergence criterion set, but aggregator is null.");
    }

    // built-in aggregator and implicit criterion for the empty-workset termination check
    headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
    syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
    syncConfig.setImplicitConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new WorksetEmptyConvergenceCriterion());
}
Also used : WorksetIterationPlanNode(org.apache.flink.optimizer.plan.WorksetIterationPlanNode) LongSumAggregator(org.apache.flink.api.common.aggregators.LongSumAggregator) TaskConfig(org.apache.flink.runtime.operators.util.TaskConfig) JobVertex(org.apache.flink.runtime.jobgraph.JobVertex) SolutionSetPlanNode(org.apache.flink.optimizer.plan.SolutionSetPlanNode) IterationPlanNode(org.apache.flink.optimizer.plan.IterationPlanNode) BulkIterationPlanNode(org.apache.flink.optimizer.plan.BulkIterationPlanNode) WorksetPlanNode(org.apache.flink.optimizer.plan.WorksetPlanNode) SingleInputPlanNode(org.apache.flink.optimizer.plan.SingleInputPlanNode) WorksetIterationPlanNode(org.apache.flink.optimizer.plan.WorksetIterationPlanNode) SourcePlanNode(org.apache.flink.optimizer.plan.SourcePlanNode) BulkPartialSolutionPlanNode(org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode) DualInputPlanNode(org.apache.flink.optimizer.plan.DualInputPlanNode) PlanNode(org.apache.flink.optimizer.plan.PlanNode) SinkPlanNode(org.apache.flink.optimizer.plan.SinkPlanNode) NAryUnionPlanNode(org.apache.flink.optimizer.plan.NAryUnionPlanNode) IterationSynchronizationSinkTask(org.apache.flink.runtime.iterative.task.IterationSynchronizationSinkTask) WorksetEmptyConvergenceCriterion(org.apache.flink.runtime.iterative.convergence.WorksetEmptyConvergenceCriterion) CompilerException(org.apache.flink.optimizer.CompilerException) AggregatorWithName(org.apache.flink.api.common.aggregators.AggregatorWithName) AggregatorRegistry(org.apache.flink.api.common.aggregators.AggregatorRegistry)

Example 2 with AggregatorRegistry

Use of org.apache.flink.api.common.aggregators.AggregatorRegistry in the Apache Flink project.

From the class JobGraphGenerator, method finalizeBulkIteration:

/**
 * Finishes the job-graph wiring of a bulk iteration.
 *
 * <p>In order, this method completes the head task configuration (final output config, index
 * of the sync output, back-channel memory), creates the synchronization sink vertex
 * (parallelism 1) and connects it to the head, configures the tail of the step function and,
 * if present as a tail, the termination criterion, and finally registers the aggregators and
 * the optional convergence criterion.
 *
 * @param descr descriptor supplying the iteration plan node, the head task and the head's
 *     final-result configuration
 * @throws CompilerException if the head has no step-function outputs, no back-channel memory
 *     was assigned, the maximum number of iterations is below one, a tail node is neither a
 *     vertex nor a chained task, or only one half of the convergence criterion is set
 */
private void finalizeBulkIteration(IterationDescriptor descr) {
    final BulkIterationPlanNode bulkIteration = (BulkIterationPlanNode) descr.getIterationNode();
    final JobVertex head = descr.getHeadTask();
    final TaskConfig headConfig = new TaskConfig(head.getConfiguration());
    final TaskConfig finalOutputConfig = descr.getHeadFinalResultConfig();

    // ---------------------------------- head configuration ----------------------------------
    final int stepFunctionOutputs = headConfig.getNumOutputs();
    final int finalOutputs = finalOutputConfig.getNumOutputs();
    if (stepFunctionOutputs == 0) {
        throw new CompilerException("The iteration has no operation inside the step function.");
    }
    headConfig.setIterationHeadFinalOutputConfig(finalOutputConfig);
    // the sync output is indexed right after the step function outputs and the final outputs
    headConfig.setIterationHeadIndexOfSyncOutput(stepFunctionOutputs + finalOutputs);

    final double backChannelMemory = bulkIteration.getRelativeMemoryPerSubTask();
    if (backChannelMemory <= 0) {
        throw new CompilerException("Bug: No memory has been assigned to the iteration back channel.");
    }
    headConfig.setRelativeBackChannelMemory(backChannelMemory);

    // -------------------------------------- sync task ---------------------------------------
    final JobVertex syncVertex = new JobVertex("Sync(" + bulkIteration.getNodeName() + ")");
    syncVertex.setResources(bulkIteration.getMinResources(), bulkIteration.getPreferredResources());
    syncVertex.setInvokableClass(IterationSynchronizationSinkTask.class);
    syncVertex.setParallelism(1);
    syncVertex.setMaxParallelism(1);
    this.auxVertices.add(syncVertex);

    final TaskConfig syncConfig = new TaskConfig(syncVertex.getConfiguration());
    syncConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(0, head.getParallelism());

    // the sync task needs an explicit upper bound on the number of iterations
    final int iterationLimit = bulkIteration.getIterationNode().getIterationContract().getMaximumNumberOfIterations();
    if (iterationLimit < 1) {
        throw new CompilerException("Cannot create bulk iteration with unspecified maximum number of iterations.");
    }
    syncConfig.setNumberOfIterations(iterationLimit);
    syncVertex.connectNewDataSetAsInput(head, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);

    // ------------------------------------ iteration tail ------------------------------------
    final PlanNode terminationRoot = bulkIteration.getRootOfTerminationCriterion();
    final PlanNode stepFunctionRoot = bulkIteration.getRootOfStepFunction();

    // locate the last task of the step function, either as its own vertex or inside a chain
    JobVertex stepFunctionVertex = this.vertices.get(stepFunctionRoot);
    final TaskConfig stepTailConfig;
    if (stepFunctionVertex != null) {
        stepTailConfig = new TaskConfig(stepFunctionVertex.getConfiguration());
    } else {
        final TaskInChain chainedStepTail = this.chainedTasks.get(stepFunctionRoot);
        if (chainedStepTail == null) {
            throw new CompilerException("Bug: Tail of step function not found as vertex or chained task.");
        }
        stepFunctionVertex = chainedStepTail.getContainingVertex();
        stepTailConfig = chainedStepTail.getTaskConfig();
    }
    stepTailConfig.setIsWorksetUpdate();

    // only a step function root without outgoing channels becomes the iteration tail task
    if (stepFunctionRoot.getOutgoingChannels().isEmpty()) {
        stepFunctionVertex.setInvokableClass(IterationTailTask.class);
        stepTailConfig.setOutputSerializer(bulkIteration.getSerializerForIterationChannel());
    }

    // a termination criterion gets its own tail only if it exists and is not an intermediate node
    final boolean terminationIsTail = terminationRoot != null && terminationRoot.getOutgoingChannels().isEmpty();
    if (terminationIsTail) {
        JobVertex terminationVertex = this.vertices.get(terminationRoot);
        final TaskConfig terminationTailConfig;
        if (terminationVertex != null) {
            terminationTailConfig = new TaskConfig(terminationVertex.getConfiguration());
        } else {
            final TaskInChain chainedTerminationTail = this.chainedTasks.get(terminationRoot);
            if (chainedTerminationTail == null) {
                throw new CompilerException("Bug: Tail of termination criterion not found as vertex or chained task.");
            }
            terminationVertex = chainedTerminationTail.getContainingVertex();
            terminationTailConfig = chainedTerminationTail.getTaskConfig();
        }
        terminationVertex.setInvokableClass(IterationTailTask.class);
        // Hack: the termination criterion tail is flagged via the solution-set-update marker
        terminationTailConfig.setIsSolutionSetUpdate();
        terminationTailConfig.setOutputSerializer(bulkIteration.getSerializerForIterationChannel());
        // the head has to wait for this second tail as well
        headConfig.setWaitForSolutionSetUpdate();
    }

    // -------------------------------------- aggregators -------------------------------------
    final AggregatorRegistry registry = bulkIteration.getIterationNode().getIterationContract().getAggregators();
    final Collection<AggregatorWithName<?>> userAggregators = registry.getAllRegisteredAggregators();
    headConfig.addIterationAggregators(userAggregators);
    syncConfig.addIterationAggregators(userAggregators);

    // a convergence criterion is only valid together with its aggregator name
    final String criterionAggregatorName = registry.getConvergenceCriterionAggregatorName();
    final ConvergenceCriterion<?> criterion = registry.getConvergenceCriterion();
    if (criterion != null && criterionAggregatorName != null) {
        syncConfig.setConvergenceCriterion(criterionAggregatorName, criterion);
    } else if (criterionAggregatorName != null) {
        throw new CompilerException("Error: Convergence criterion aggregator set, but criterion is null.");
    } else if (criterion != null) {
        throw new CompilerException("Error: Aggregator convergence criterion set, but aggregator is null.");
    }
}
Also used : TaskConfig(org.apache.flink.runtime.operators.util.TaskConfig) JobVertex(org.apache.flink.runtime.jobgraph.JobVertex) SolutionSetPlanNode(org.apache.flink.optimizer.plan.SolutionSetPlanNode) IterationPlanNode(org.apache.flink.optimizer.plan.IterationPlanNode) BulkIterationPlanNode(org.apache.flink.optimizer.plan.BulkIterationPlanNode) WorksetPlanNode(org.apache.flink.optimizer.plan.WorksetPlanNode) SingleInputPlanNode(org.apache.flink.optimizer.plan.SingleInputPlanNode) WorksetIterationPlanNode(org.apache.flink.optimizer.plan.WorksetIterationPlanNode) SourcePlanNode(org.apache.flink.optimizer.plan.SourcePlanNode) BulkPartialSolutionPlanNode(org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode) DualInputPlanNode(org.apache.flink.optimizer.plan.DualInputPlanNode) PlanNode(org.apache.flink.optimizer.plan.PlanNode) SinkPlanNode(org.apache.flink.optimizer.plan.SinkPlanNode) NAryUnionPlanNode(org.apache.flink.optimizer.plan.NAryUnionPlanNode) CompilerException(org.apache.flink.optimizer.CompilerException) AggregatorWithName(org.apache.flink.api.common.aggregators.AggregatorWithName) BulkIterationPlanNode(org.apache.flink.optimizer.plan.BulkIterationPlanNode) AggregatorRegistry(org.apache.flink.api.common.aggregators.AggregatorRegistry)

Aggregations

AggregatorRegistry (org.apache.flink.api.common.aggregators.AggregatorRegistry)2 AggregatorWithName (org.apache.flink.api.common.aggregators.AggregatorWithName)2 CompilerException (org.apache.flink.optimizer.CompilerException)2 BulkIterationPlanNode (org.apache.flink.optimizer.plan.BulkIterationPlanNode)2 BulkPartialSolutionPlanNode (org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode)2 DualInputPlanNode (org.apache.flink.optimizer.plan.DualInputPlanNode)2 IterationPlanNode (org.apache.flink.optimizer.plan.IterationPlanNode)2 NAryUnionPlanNode (org.apache.flink.optimizer.plan.NAryUnionPlanNode)2 PlanNode (org.apache.flink.optimizer.plan.PlanNode)2 SingleInputPlanNode (org.apache.flink.optimizer.plan.SingleInputPlanNode)2 SinkPlanNode (org.apache.flink.optimizer.plan.SinkPlanNode)2 SolutionSetPlanNode (org.apache.flink.optimizer.plan.SolutionSetPlanNode)2 SourcePlanNode (org.apache.flink.optimizer.plan.SourcePlanNode)2 WorksetIterationPlanNode (org.apache.flink.optimizer.plan.WorksetIterationPlanNode)2 WorksetPlanNode (org.apache.flink.optimizer.plan.WorksetPlanNode)2 JobVertex (org.apache.flink.runtime.jobgraph.JobVertex)2 TaskConfig (org.apache.flink.runtime.operators.util.TaskConfig)2 LongSumAggregator (org.apache.flink.api.common.aggregators.LongSumAggregator)1 WorksetEmptyConvergenceCriterion (org.apache.flink.runtime.iterative.convergence.WorksetEmptyConvergenceCriterion)1 IterationSynchronizationSinkTask (org.apache.flink.runtime.iterative.task.IterationSynchronizationSinkTask)1