Search in sources :

Example 1 with ReplicateBlockFunction

use of org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction in project incubator-systemml by apache.

the class QuaternarySPInstruction method processInstruction.

/**
 * Runs this quaternary instruction on Spark.
 *
 * <p>The first input is always consumed as a binary-block RDD. For the opcodes
 * matched against the {@code Weighted*.OPCODE} constants below, both factor
 * inputs are read as broadcasts and applied in one {@code mapPartitionsToPair}
 * pass. For every other opcode each factor is either a broadcast or a
 * replicated RDD joined onto the first input, as selected by {@code _cacheU}
 * and {@code _cacheV}; an optional fourth, non-literal input is joined as well.
 *
 * <p>If {@code wtype1} or {@code wtype4} of the operator is set, the result is
 * summed and stored as a {@code DoubleObject} under the output name. Otherwise
 * the result RDD is registered under the output name, lineage to the consumed
 * RDDs and broadcasts is recorded, and the output characteristics are updated.
 *
 * @param ec execution context; cast to {@code SparkExecutionContext}
 */
@Override
public void processInstruction(ExecutionContext ec) {
    final SparkExecutionContext sparkEc = (SparkExecutionContext) ec;
    final QuaternaryOperator op = (QuaternaryOperator) _optr;

    // names of consumed rdds and broadcasts, recorded for lineage maintenance
    final ArrayList<String> lineageRdds = new ArrayList<>();
    final ArrayList<String> lineageBroadcasts = new ArrayList<>();

    JavaPairRDD<MatrixIndexes, MatrixBlock> inX = sparkEc.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> result = null;

    final MatrixCharacteristics mcX = sparkEc.getMatrixCharacteristics(input1.getName());
    final long numRows = mcX.getRows();
    final long numCols = mcX.getCols();
    final int rowsPerBlock = mcX.getRowsPerBlock();
    final int colsPerBlock = mcX.getColsPerBlock();

    // drop empty blocks up front for the wtype1 / wtype4 operators
    // (map/redwsloss, map/redwcemm); safe because these ops produce a scalar
    if (!(op.wtype1 == null && op.wtype4 == null)) {
        inX = inX.filter(new FilterNonEmptyBlocksFunction());
    }

    final String opcode = getOpcode();
    final boolean mapSideOnly = WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)
        || WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode)
        || WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)
        || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode)
        || WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode);

    if (mapSideOnly) {
        // map-side only operation: one rdd input, two broadcasts
        PartitionedBroadcast<MatrixBlock> bcU = sparkEc.getBroadcastForVariable(input2.getName());
        PartitionedBroadcast<MatrixBlock> bcV = sparkEc.getBroadcastForVariable(input3.getName());

        // the flag is handed to mapPartitionsToPair as its second argument;
        // it is false only for a non-basic wtype3 (only wdivmm changes keys)
        boolean keysPreserved = (op.wtype3 == null || op.wtype3.isBasic());
        result = inX.mapPartitionsToPair(new RDDQuaternaryFunction1(op, bcU, bcV), keysPreserved);

        lineageRdds.add(input1.getName());
        lineageBroadcasts.add(input2.getName());
        lineageBroadcasts.add(input3.getName());
    } else {
        // reduce-side operation: two/three/four rdd inputs, zero/one/two broadcasts
        PartitionedBroadcast<MatrixBlock> bcU = _cacheU ? sparkEc.getBroadcastForVariable(input2.getName()) : null;
        PartitionedBroadcast<MatrixBlock> bcV = _cacheV ? sparkEc.getBroadcastForVariable(input3.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inU = _cacheU ? null : sparkEc.getBinaryBlockRDDHandleForVariable(input2.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> inV = _cacheV ? null : sparkEc.getBinaryBlockRDDHandleForVariable(input3.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> inW = null;
        if (op.hasFourInputs() && !_input4.isLiteral()) {
            inW = sparkEc.getBinaryBlockRDDHandleForVariable(_input4.getName());
        }

        // replicate U using the column extent of the first input
        if (inU != null) {
            inU = inU.flatMapToPair(new ReplicateBlockFunction(numCols, colsPerBlock, true));
        }
        // transpose the indexes of V, then replicate using the row extent of the first input
        if (inV != null) {
            inV = inV.mapToPair(new TransposeFactorIndexesFunction())
                .flatMapToPair(new ReplicateBlockFunction(numRows, rowsPerBlock, false));
        }

        final boolean joinU = (inU != null);
        final boolean joinV = (inV != null);
        final boolean joinW = (inW != null);

        // pick the function variant by the number of rdd inputs joined onto the first input
        if (joinU && joinV && joinW) {
            // four rdd inputs (need keys in case of wdivmm)
            result = inX.join(inU).join(inV).join(inW).mapToPair(new RDDQuaternaryFunction4(op));
        } else if (joinU && joinV) {
            // three rdd inputs
            result = inX.join(inU).join(inV).mapToPair(new RDDQuaternaryFunction3(op, bcU, bcV));
        } else if (joinU && joinW) {
            result = inX.join(inU).join(inW).mapToPair(new RDDQuaternaryFunction3(op, bcU, bcV));
        } else if (joinV && joinW) {
            result = inX.join(inV).join(inW).mapToPair(new RDDQuaternaryFunction3(op, bcU, bcV));
        } else if (joinU) {
            // two rdd inputs
            result = inX.join(inU).mapToPair(new RDDQuaternaryFunction2(op, bcU, bcV));
        } else if (joinV) {
            result = inX.join(inV).mapToPair(new RDDQuaternaryFunction2(op, bcU, bcV));
        } else if (joinW) {
            result = inX.join(inW).mapToPair(new RDDQuaternaryFunction2(op, bcU, bcV));
        } else {
            // no additional rdd input: both factors are broadcasts
            result = inX.mapPartitionsToPair(new RDDQuaternaryFunction1(op, bcU, bcV), false);
        }

        // keep variable names for lineage maintenance
        (joinU ? lineageRdds : lineageBroadcasts).add(input2.getName());
        (joinV ? lineageRdds : lineageBroadcasts).add(input3.getName());
        if (joinW) {
            lineageRdds.add(_input4.getName());
        }
    }

    // scalar output (map/redwsloss, map/redwcemm): full aggregate and cast to scalar
    if (op.wtype1 != null || op.wtype4 != null) {
        MatrixBlock total = RDDAggregateUtils.sumStable(result);
        DoubleObject scalar = new DoubleObject(total.getValue(0, 0));
        sparkEc.setVariable(output.getName(), scalar);
        return;
    }

    // matrix output (map/redwsigmoid, map/redwdivmm, map/redwumm)
    // aggregation by key is only applied for a non-basic wtype3 (map/redwdivmm)
    if (op.wtype3 != null && !op.wtype3.isBasic()) {
        result = RDDAggregateUtils.sumByKeyStable(result, false);
    }

    // put output RDD handle into symbol table
    sparkEc.setRDDHandleForVariable(output.getName(), result);

    // maintain lineage information for output rdd
    for (String rddName : lineageRdds) {
        sparkEc.addLineageRDD(output.getName(), rddName);
    }
    for (String bcName : lineageBroadcasts) {
        sparkEc.addLineageBroadcast(output.getName(), bcName);
    }

    // update matrix characteristics
    updateOutputMatrixCharacteristics(sparkEc, op);
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) FilterNonEmptyBlocksFunction(org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) ArrayList(java.util.ArrayList) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) ReplicateBlockFunction(org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction)

Example 2 with ReplicateBlockFunction

use of org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction in project systemml by apache.

the class SpoofSPInstruction method createJoinedInputRDD.

/**
 * Joins the main input RDD with every matrix side input that is not flagged
 * in {@code bcVect}, producing one {@code MatrixBlock[]} per block index.
 *
 * <p>Before each join a side input is replicated when its block layout is
 * smaller than the main input's: the input at position 2 of an outer
 * operation uses {@code ReplicateRightFactorFunction}; otherwise a side input
 * with fewer row blocks, or else fewer column blocks, than the main input
 * uses {@code ReplicateBlockFunction}.
 *
 * @param sec    Spark execution context used to look up RDD handles and matrix characteristics
 * @param inputs all instruction inputs
 * @param bcVect per-input flag; flagged inputs are skipped here (presumably handled as broadcasts by the caller)
 * @param outer  whether the input at position 2 is replicated as a right factor
 * @return the main input joined with all selected side inputs
 */
private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, boolean outer) {
    // the main input defines the block layout that side inputs are aligned to
    final int mainPos = getMainInputIndex(inputs, bcVect);
    final MatrixCharacteristics mainMc = sec.getMatrixCharacteristics(inputs[mainPos].getName());
    final JavaPairRDD<MatrixIndexes, MatrixBlock> mainRdd = sec.getBinaryBlockRDDHandleForVariable(inputs[mainPos].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock[]> joined = mainRdd.mapValues(new MapInputSignature());

    for (int pos = 0; pos < inputs.length; pos++) {
        // only non-main matrix inputs that are not flagged in bcVect are joined
        if (pos == mainPos || !inputs[pos].getDataType().isMatrix() || bcVect[pos]) {
            continue;
        }

        // create side input rdd
        final String sideName = inputs[pos].getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> side = sec.getBinaryBlockRDDHandleForVariable(sideName);
        final MatrixCharacteristics sideMc = sec.getMatrixCharacteristics(sideName);

        // replicate blocks if the side input's layout does not match the main input
        if (outer && pos == 2) {
            side = side.flatMapToPair(new ReplicateRightFactorFunction(mainMc.getRows(), mainMc.getRowsPerBlock()));
        } else if (mainMc.getNumRowBlocks() > sideMc.getNumRowBlocks()) {
            side = side.flatMapToPair(new ReplicateBlockFunction(mainMc.getRows(), mainMc.getRowsPerBlock(), false));
        } else if (mainMc.getNumColBlocks() > sideMc.getNumColBlocks()) {
            side = side.flatMapToPair(new ReplicateBlockFunction(mainMc.getCols(), mainMc.getColsPerBlock(), true));
        }

        // join main and side inputs and consolidate signature
        joined = joined.join(side).mapValues(new MapJoinSignature());
    }
    return joined;
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) ReplicateBlockFunction(org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction)

Example 3 with ReplicateBlockFunction

use of org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction in project incubator-systemml by apache.

the class SpoofSPInstruction method createJoinedInputRDD.

/**
 * Joins the main input RDD with every matrix side input that is not flagged
 * in {@code bcVect}, producing one {@code MatrixBlock[]} per block index.
 *
 * <p>Before each join a side input is replicated when its block layout is
 * smaller than the main input's: the input at position 2 of an outer
 * operation uses {@code ReplicateRightFactorFunction}; otherwise a side input
 * with fewer row blocks, or else fewer column blocks, than the main input
 * uses {@code ReplicateBlockFunction}.
 *
 * @param sec    Spark execution context used to look up RDD handles and matrix characteristics
 * @param inputs all instruction inputs
 * @param bcVect per-input flag; flagged inputs are skipped here (presumably handled as broadcasts by the caller)
 * @param outer  whether the input at position 2 is replicated as a right factor
 * @return the main input joined with all selected side inputs
 */
private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, boolean outer) {
    // the main input defines the block layout that side inputs are aligned to
    final int mainPos = getMainInputIndex(inputs, bcVect);
    final MatrixCharacteristics mainMc = sec.getMatrixCharacteristics(inputs[mainPos].getName());
    final JavaPairRDD<MatrixIndexes, MatrixBlock> mainRdd = sec.getBinaryBlockRDDHandleForVariable(inputs[mainPos].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock[]> joined = mainRdd.mapValues(new MapInputSignature());

    for (int pos = 0; pos < inputs.length; pos++) {
        // only non-main matrix inputs that are not flagged in bcVect are joined
        if (pos == mainPos || !inputs[pos].getDataType().isMatrix() || bcVect[pos]) {
            continue;
        }

        // create side input rdd
        final String sideName = inputs[pos].getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> side = sec.getBinaryBlockRDDHandleForVariable(sideName);
        final MatrixCharacteristics sideMc = sec.getMatrixCharacteristics(sideName);

        // replicate blocks if the side input's layout does not match the main input
        if (outer && pos == 2) {
            side = side.flatMapToPair(new ReplicateRightFactorFunction(mainMc.getRows(), mainMc.getRowsPerBlock()));
        } else if (mainMc.getNumRowBlocks() > sideMc.getNumRowBlocks()) {
            side = side.flatMapToPair(new ReplicateBlockFunction(mainMc.getRows(), mainMc.getRowsPerBlock(), false));
        } else if (mainMc.getNumColBlocks() > sideMc.getNumColBlocks()) {
            side = side.flatMapToPair(new ReplicateBlockFunction(mainMc.getCols(), mainMc.getColsPerBlock(), true));
        }

        // join main and side inputs and consolidate signature
        joined = joined.join(side).mapValues(new MapJoinSignature());
    }
    return joined;
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) ReplicateBlockFunction(org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction)

Example 4 with ReplicateBlockFunction

use of org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction in project systemml by apache.

the class QuaternarySPInstruction method processInstruction.

/**
 * Runs this quaternary instruction on Spark.
 *
 * <p>The first input is always consumed as a binary-block RDD. For the opcodes
 * matched against the {@code Weighted*.OPCODE} constants below, both factor
 * inputs are read as broadcasts and applied in one {@code mapPartitionsToPair}
 * pass. For every other opcode each factor is either a broadcast or a
 * replicated RDD joined onto the first input, as selected by {@code _cacheU}
 * and {@code _cacheV}; an optional fourth, non-literal input is joined as well.
 *
 * <p>If {@code wtype1} or {@code wtype4} of the operator is set, the result is
 * summed and stored as a {@code DoubleObject} under the output name. Otherwise
 * the result RDD is registered under the output name, lineage to the consumed
 * RDDs and broadcasts is recorded, and the output characteristics are updated.
 *
 * @param ec execution context; cast to {@code SparkExecutionContext}
 */
@Override
public void processInstruction(ExecutionContext ec) {
    final SparkExecutionContext sparkEc = (SparkExecutionContext) ec;
    final QuaternaryOperator op = (QuaternaryOperator) _optr;

    // names of consumed rdds and broadcasts, recorded for lineage maintenance
    final ArrayList<String> lineageRdds = new ArrayList<>();
    final ArrayList<String> lineageBroadcasts = new ArrayList<>();

    JavaPairRDD<MatrixIndexes, MatrixBlock> inX = sparkEc.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> result = null;

    final MatrixCharacteristics mcX = sparkEc.getMatrixCharacteristics(input1.getName());
    final long numRows = mcX.getRows();
    final long numCols = mcX.getCols();
    final int rowsPerBlock = mcX.getRowsPerBlock();
    final int colsPerBlock = mcX.getColsPerBlock();

    // drop empty blocks up front for the wtype1 / wtype4 operators
    // (map/redwsloss, map/redwcemm); safe because these ops produce a scalar
    if (!(op.wtype1 == null && op.wtype4 == null)) {
        inX = inX.filter(new FilterNonEmptyBlocksFunction());
    }

    final String opcode = getOpcode();
    final boolean mapSideOnly = WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)
        || WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode)
        || WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)
        || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode)
        || WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode);

    if (mapSideOnly) {
        // map-side only operation: one rdd input, two broadcasts
        PartitionedBroadcast<MatrixBlock> bcU = sparkEc.getBroadcastForVariable(input2.getName());
        PartitionedBroadcast<MatrixBlock> bcV = sparkEc.getBroadcastForVariable(input3.getName());

        // the flag is handed to mapPartitionsToPair as its second argument;
        // it is false only for a non-basic wtype3 (only wdivmm changes keys)
        boolean keysPreserved = (op.wtype3 == null || op.wtype3.isBasic());
        result = inX.mapPartitionsToPair(new RDDQuaternaryFunction1(op, bcU, bcV), keysPreserved);

        lineageRdds.add(input1.getName());
        lineageBroadcasts.add(input2.getName());
        lineageBroadcasts.add(input3.getName());
    } else {
        // reduce-side operation: two/three/four rdd inputs, zero/one/two broadcasts
        PartitionedBroadcast<MatrixBlock> bcU = _cacheU ? sparkEc.getBroadcastForVariable(input2.getName()) : null;
        PartitionedBroadcast<MatrixBlock> bcV = _cacheV ? sparkEc.getBroadcastForVariable(input3.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inU = _cacheU ? null : sparkEc.getBinaryBlockRDDHandleForVariable(input2.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> inV = _cacheV ? null : sparkEc.getBinaryBlockRDDHandleForVariable(input3.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> inW = null;
        if (op.hasFourInputs() && !_input4.isLiteral()) {
            inW = sparkEc.getBinaryBlockRDDHandleForVariable(_input4.getName());
        }

        // replicate U using the column extent of the first input
        if (inU != null) {
            inU = inU.flatMapToPair(new ReplicateBlockFunction(numCols, colsPerBlock, true));
        }
        // transpose the indexes of V, then replicate using the row extent of the first input
        if (inV != null) {
            inV = inV.mapToPair(new TransposeFactorIndexesFunction())
                .flatMapToPair(new ReplicateBlockFunction(numRows, rowsPerBlock, false));
        }

        final boolean joinU = (inU != null);
        final boolean joinV = (inV != null);
        final boolean joinW = (inW != null);

        // pick the function variant by the number of rdd inputs joined onto the first input
        if (joinU && joinV && joinW) {
            // four rdd inputs (need keys in case of wdivmm)
            result = inX.join(inU).join(inV).join(inW).mapToPair(new RDDQuaternaryFunction4(op));
        } else if (joinU && joinV) {
            // three rdd inputs
            result = inX.join(inU).join(inV).mapToPair(new RDDQuaternaryFunction3(op, bcU, bcV));
        } else if (joinU && joinW) {
            result = inX.join(inU).join(inW).mapToPair(new RDDQuaternaryFunction3(op, bcU, bcV));
        } else if (joinV && joinW) {
            result = inX.join(inV).join(inW).mapToPair(new RDDQuaternaryFunction3(op, bcU, bcV));
        } else if (joinU) {
            // two rdd inputs
            result = inX.join(inU).mapToPair(new RDDQuaternaryFunction2(op, bcU, bcV));
        } else if (joinV) {
            result = inX.join(inV).mapToPair(new RDDQuaternaryFunction2(op, bcU, bcV));
        } else if (joinW) {
            result = inX.join(inW).mapToPair(new RDDQuaternaryFunction2(op, bcU, bcV));
        } else {
            // no additional rdd input: both factors are broadcasts
            result = inX.mapPartitionsToPair(new RDDQuaternaryFunction1(op, bcU, bcV), false);
        }

        // keep variable names for lineage maintenance
        (joinU ? lineageRdds : lineageBroadcasts).add(input2.getName());
        (joinV ? lineageRdds : lineageBroadcasts).add(input3.getName());
        if (joinW) {
            lineageRdds.add(_input4.getName());
        }
    }

    // scalar output (map/redwsloss, map/redwcemm): full aggregate and cast to scalar
    if (op.wtype1 != null || op.wtype4 != null) {
        MatrixBlock total = RDDAggregateUtils.sumStable(result);
        DoubleObject scalar = new DoubleObject(total.getValue(0, 0));
        sparkEc.setVariable(output.getName(), scalar);
        return;
    }

    // matrix output (map/redwsigmoid, map/redwdivmm, map/redwumm)
    // aggregation by key is only applied for a non-basic wtype3 (map/redwdivmm)
    if (op.wtype3 != null && !op.wtype3.isBasic()) {
        result = RDDAggregateUtils.sumByKeyStable(result, false);
    }

    // put output RDD handle into symbol table
    sparkEc.setRDDHandleForVariable(output.getName(), result);

    // maintain lineage information for output rdd
    for (String rddName : lineageRdds) {
        sparkEc.addLineageRDD(output.getName(), rddName);
    }
    for (String bcName : lineageBroadcasts) {
        sparkEc.addLineageBroadcast(output.getName(), bcName);
    }

    // update matrix characteristics
    updateOutputMatrixCharacteristics(sparkEc, op);
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) FilterNonEmptyBlocksFunction(org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) ArrayList(java.util.ArrayList) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) ReplicateBlockFunction(org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction)

Aggregations

ReplicateBlockFunction (org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction): 4; MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics): 4; MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 4; MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes): 4; ArrayList (java.util.ArrayList): 2; JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 2; SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext): 2; DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject): 2; FilterNonEmptyBlocksFunction (org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction): 2; QuaternaryOperator (org.apache.sysml.runtime.matrix.operators.QuaternaryOperator): 2