Search in sources :

Example 6 with AggregateDropCorrectionFunction

Use of org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction in the Apache SystemML project.

The method processInstruction of the class UaggOuterChainSPInstruction.

/**
 * Runs this unary-aggregate outer-chain instruction against a Spark execution context.
 * <p>
 * One operand is consumed as a binary-block RDD and the other one is shipped to the
 * executors as a broadcast. The RDD is mapped partition-wise with the outer-chain
 * function, and the mapped result is then either folded into a single output block
 * (when the index function is {@code ReduceAll}) or registered as the output RDD.
 *
 * @param ec execution context; it is cast to {@link SparkExecutionContext} without a
 *           preceding type check
 */
@Override
public void processInstruction(ExecutionContext ec) {
    final SparkExecutionContext sparkCtx = (SparkExecutionContext) ec;

    // Decide which operand stays distributed (rdd) and which one is broadcast.
    final boolean colOrFullReduce = _uaggOp.indexFn instanceof ReduceCol || _uaggOp.indexFn instanceof ReduceAll;
    final boolean rightCached = colOrFullReduce || !LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp);
    final String rddVar;
    String bcastVar;
    if (rightCached) {
        rddVar = input1.getName();
        bcastVar = input2.getName();
    } else {
        rddVar = input2.getName();
        bcastVar = input1.getName();
    }

    // Look up the distributed operand together with its matrix characteristics.
    final JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = sparkCtx.getBinaryBlockRDDHandleForVariable(rddVar);
    final MatrixCharacteristics inputMc = sparkCtx.getMatrixCharacteristics(rddVar);
    final boolean keepPartitioning = preservesPartitioning(inputMc, _uaggOp.indexFn);

    // Apply the outer-chain map function; which function is used depends on operator support.
    final JavaPairRDD<MatrixIndexes, MatrixBlock> mapped;
    if (!LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp)) {
        // Generic path: hand the broadcast operand over as a partitioned broadcast.
        final PartitionedBroadcast<MatrixBlock> partBcast = sparkCtx.getBroadcastForVariable(bcastVar);
        mapped = blocks.mapPartitionsToPair(
            new RDDMapGenUAggOuterChainFunction(partBcast, _uaggOp, _aggOp, _bOp, inputMc), keepPartitioning);
    } else {
        // Specialized path: read the broadcast operand locally and flatten it to a vector.
        final MatrixBlock bcastBlock = sparkCtx.getMatrixInput(bcastVar, getExtendedOpcode());
        sparkCtx.releaseMatrixInput(bcastVar, getExtendedOpcode());
        // Clearing the name suppresses the broadcast lineage registration further below.
        bcastVar = null;
        final double[] values = DataConverter.convertToDoubleVector(bcastBlock);

        // Builtin increment functions get prepared row indices; otherwise the values are sorted in place.
        Broadcast<int[]> indexBcast = null;
        if (_uaggOp.aggOp.increOp.fn instanceof Builtin) {
            final int[] rowIndices = LibMatrixOuterAgg.prepareRowIndices(bcastBlock.getNumColumns(), values, _bOp, _uaggOp);
            indexBcast = sparkCtx.getSparkContext().broadcast(rowIndices);
        } else {
            Arrays.sort(values);
        }
        final Broadcast<double[]> valueBcast = sparkCtx.getSparkContext().broadcast(values);
        mapped = blocks.mapPartitionsToPair(
            new RDDMapUAggOuterChainFunction(valueBcast, indexBcast, _bOp, _uaggOp), keepPartitioning);
    }

    // Full aggregation: fold everything into one block and store it directly.
    if (_uaggOp.indexFn instanceof ReduceAll) {
        final MatrixBlock aggregated = RDDAggregateUtils.aggStable(mapped, _aggOp);
        // The correction rows/columns are only needed during aggregation.
        aggregated.dropLastRowsOrColumns(_aggOp.correctionLocation);
        // Single output block, so no lineage is recorded for it.
        sparkCtx.setMatrixOutput(output.getName(), aggregated, getExtendedOpcode());
        return;
    }

    // Row/column aggregation: the output remains an RDD registered in the symbol table.
    updateUnaryAggOutputMatrixCharacteristics(sparkCtx);
    JavaPairRDD<MatrixIndexes, MatrixBlock> result = mapped;
    if (_uaggOp.aggOp.correctionExists) {
        result = mapped.mapValues(new AggregateDropCorrectionFunction(_uaggOp.aggOp));
    }
    sparkCtx.setRDDHandleForVariable(output.getName(), result);
    sparkCtx.addLineageRDD(output.getName(), rddVar);
    if (bcastVar != null) {
        sparkCtx.addLineageBroadcast(output.getName(), bcastVar);
    }
}
Also used : ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) AggregateDropCorrectionFunction(org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Aggregations

SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) 6, AggregateDropCorrectionFunction (org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction) 6, MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics) 6, MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock) 6, MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes) 6, ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll) 4, Builtin (org.apache.sysml.runtime.functionobjects.Builtin) 2, ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol) 2, DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject) 2, FilterDiagBlocksFunction (org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction) 2, AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator) 2, AggregateTernaryOperator (org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator) 2, AggregateUnaryOperator (org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator) 2