
Example 1 with MatrixVectorBinaryOpPartitionFunction

Use of org.apache.sysml.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction in the project incubator-systemml by Apache.

From the class BinarySPInstruction, method processMatrixBVectorBinaryInstruction.

protected void processMatrixBVectorBinaryInstruction(ExecutionContext ec, VectorType vtype) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    //sanity check dimensions
    checkMatrixMatrixBinaryCharacteristics(sec);
    //get input RDDs
    String rddVar = input1.getName();
    String bcastVar = input2.getName();
    JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
    PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);
    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(rddVar);
    MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(bcastVar);
    BinaryOperator bop = (BinaryOperator) _optr;
    boolean isOuter = (mc1.getRows() > 1 && mc1.getCols() == 1 && mc2.getRows() == 1 && mc2.getCols() > 1);
    //execute map binary operation
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
    if (isOuter) {
        out = in1.flatMapToPair(new OuterVectorBinaryOpFunction(bop, in2));
    } else {
        //default
        //note: we use mapPartitionsToPair (with preservesPartitioning=true) to preserve
        //partitioning information for binary matrix-vector operations, where the keys are
        //guaranteed not to change; we cannot use mapValues because the function needs
        //the block key for its broadcast lookups.
        //alternative: out = in1.mapToPair(new MatrixVectorBinaryOpFunction(bop, in2, vtype));
        out = in1.mapPartitionsToPair(new MatrixVectorBinaryOpPartitionFunction(bop, in2, vtype), true);
    }
    //set output RDD
    updateBinaryOutputMatrixCharacteristics(sec);
    sec.setRDDHandleForVariable(output.getName(), out);
    sec.addLineageRDD(output.getName(), rddVar);
    sec.addLineageBroadcast(output.getName(), bcastVar);
}
Also used: MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock), MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes), MatrixVectorBinaryOpPartitionFunction (org.apache.sysml.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction), SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext), BinaryOperator (org.apache.sysml.runtime.matrix.operators.BinaryOperator), MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics), OuterVectorBinaryOpFunction (org.apache.sysml.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction)
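
The method above chooses between two code paths based only on the input shapes: an outer operation when a column vector meets a row vector, and otherwise a cell-wise matrix-vector operation in which the vector is broadcast along rows or columns. The following is a minimal plain-Java sketch of those two semantics on dense arrays, without Spark or SystemML. The class MatrixVectorSketch, the enum VectorKind, and the helpers matrixVectorOp and outerOp are hypothetical names introduced for illustration and are not part of the SystemML API; only the shape test in isOuter mirrors the condition in the method above.

```java
import java.util.function.DoubleBinaryOperator;

// Illustrative sketch only: dense arrays stand in for blocked RDDs and broadcasts.
class MatrixVectorSketch {

    // Hypothetical stand-in for the vtype argument; not the SystemML enum.
    enum VectorKind { COLUMN, ROW }

    // Same shape test as isOuter in the method above:
    // left input is an (n x 1) column vector, right input is a (1 x m) row vector.
    static boolean isOuter(long rows1, long cols1, long rows2, long cols2) {
        return rows1 > 1 && cols1 == 1 && rows2 == 1 && cols2 > 1;
    }

    // Cell-wise op(m[i][j], v[...]): a COLUMN vector is indexed by row i,
    // a ROW vector is indexed by column j. The output has the shape of m,
    // which is why the block keys cannot change in the distributed version.
    static double[][] matrixVectorOp(double[][] m, double[] v,
                                     VectorKind kind, DoubleBinaryOperator op) {
        double[][] out = new double[m.length][];
        for (int i = 0; i < m.length; i++) {
            out[i] = new double[m[i].length];
            for (int j = 0; j < m[i].length; j++) {
                double rhs = (kind == VectorKind.COLUMN) ? v[i] : v[j];
                out[i][j] = op.applyAsDouble(m[i][j], rhs);
            }
        }
        return out;
    }

    // Outer operation: out[i][j] = op(col[i], row[j]). The output is larger
    // than either input, so one input cell contributes to many output cells.
    static double[][] outerOp(double[] col, double[] row, DoubleBinaryOperator op) {
        double[][] out = new double[col.length][row.length];
        for (int i = 0; i < col.length; i++) {
            for (int j = 0; j < row.length; j++) {
                out[i][j] = op.applyAsDouble(col[i], row[j]);
            }
        }
        return out;
    }
}
```

In this sketch, matrixVectorOp(m, v, VectorKind.ROW, (a, b) -> a + b) adds v[j] to every cell in column j, while outerOp expands two vectors into a full matrix. That difference in output shape is consistent with the method above using flatMapToPair for the outer case and a partition-preserving map for the default case.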
