Search in sources:

Example 1 with QuaternaryOperator

use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project incubator-systemml by apache.

Method processInstruction of class QuaternarySPInstruction.

/**
 * Executes a weighted quaternary operation (wsloss, wsigmoid, wdivmm, wcemm or wumm) on Spark.
 *
 * <p>Map-side variants (opcode equal to the WeightedSquaredLoss, WeightedSigmoid, WeightedDivMM,
 * WeightedCrossEntropy or WeightedUnaryMM opcode) process the main input (input1) as an RDD and
 * read input2/input3 from broadcasts. All other (reduce-side) variants join the main input with
 * whichever of U (input2), V (input3) and W (_input4) are available as RDDs; U and V come from
 * broadcasts instead when {@code _cacheU}/{@code _cacheV} are set.
 *
 * <p>If {@code qop.wtype1} or {@code qop.wtype4} is set (wsloss/wcemm), all result blocks are summed
 * and the single value is bound to the output variable as a {@code DoubleObject}. Otherwise the
 * result RDD (aggregated by key for non-basic wdivmm) is bound to the output variable. Lineage to
 * the consumed RDDs/broadcasts is recorded and the output matrix characteristics are updated.
 *
 * @param ec execution context; it is cast to {@code SparkExecutionContext}, so any other
 *           context type fails with a ClassCastException
 */
@Override
public void processInstruction(ExecutionContext ec) {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    QuaternaryOperator qop = (QuaternaryOperator) _optr;
    // tracking of rdds and broadcasts (for lineage maintenance)
    ArrayList<String> rddVars = new ArrayList<>();
    ArrayList<String> bcVars = new ArrayList<>();
    JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
    MatrixCharacteristics inMc = sec.getMatrixCharacteristics(input1.getName());
    long rlen = inMc.getRows();
    long clen = inMc.getCols();
    int brlen = inMc.getRowsPerBlock();
    int bclen = inMc.getColsPerBlock();
    // pre-filter the main input with FilterNonEmptyBlocksFunction for the full aggregates
    // (map/redwsloss, map/redwcemm); safe because these ops produce a scalar
    if (qop.wtype1 != null || qop.wtype4 != null) {
        in = in.filter(new FilterNonEmptyBlocksFunction());
    }
    // map-side only operation (one rdd input, two broadcasts)
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedSigmoid.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedDivMM.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode()) || WeightedUnaryMM.OPCODE.equalsIgnoreCase(getOpcode())) {
        PartitionedBroadcast<MatrixBlock> bc1 = sec.getBroadcastForVariable(input2.getName());
        PartitionedBroadcast<MatrixBlock> bc2 = sec.getBroadcastForVariable(input3.getName());
        // partitioning-preserving mappartitions (key access required for broadcast lookup)
        // only wdivmm changes keys, so partitioning is preserved unless wtype3 is a non-basic wdivmm
        boolean noKeyChange = (qop.wtype3 == null || qop.wtype3.isBasic());
        out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), noKeyChange);
        rddVars.add(input1.getName());
        bcVars.add(input2.getName());
        bcVars.add(input3.getName());
    } else // reduce-side operation (two/three/four rdd inputs, zero/one/two broadcasts)
    {
        // each of U and V is either a broadcast (cached) or an RDD, never both
        PartitionedBroadcast<MatrixBlock> bc1 = _cacheU ? sec.getBroadcastForVariable(input2.getName()) : null;
        PartitionedBroadcast<MatrixBlock> bc2 = _cacheV ? sec.getBroadcastForVariable(input3.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inU = (!_cacheU) ? sec.getBinaryBlockRDDHandleForVariable(input2.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> inV = (!_cacheV) ? sec.getBinaryBlockRDDHandleForVariable(input3.getName()) : null;
        // W only participates as an RDD when the operator has four inputs and it is not a literal
        JavaPairRDD<MatrixIndexes, MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ? sec.getBinaryBlockRDDHandleForVariable(_input4.getName()) : null;
        // preparation of transposed and replicated U
        if (inU != null)
            inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, bclen, true));
        // preparation of transposed and replicated V
        if (inV != null)
            inV = inV.mapToPair(new TransposeFactorIndexesFunction()).flatMapToPair(new ReplicateBlockFunction(rlen, brlen, false));
        // functions calls w/ two rdd inputs
        if (inU != null && inV == null && inW == null)
            out = in.join(inU).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
        else if (inU == null && inV != null && inW == null)
            out = in.join(inV).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
        else if (inU == null && inV == null && inW != null)
            out = in.join(inW).mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
        else // function calls w/ three rdd inputs
        if (inU != null && inV != null && inW == null)
            out = in.join(inU).join(inV).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
        else if (inU != null && inV == null && inW != null)
            out = in.join(inU).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
        else if (inU == null && inV != null && inW != null)
            out = in.join(inV).join(inW).mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
        else if (inU == null && inV == null && inW == null) {
            // only the main input is an RDD (U and V both broadcast); partitioning is not preserved here
            out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false);
        } else
            // function call w/ four rdd inputs
            // need keys in case of wdivmm
            out = in.join(inU).join(inV).join(inW).mapToPair(new RDDQuaternaryFunction4(qop));
        // keep variable names for lineage maintenance
        if (inU == null)
            bcVars.add(input2.getName());
        else
            rddVars.add(input2.getName());
        if (inV == null)
            bcVars.add(input3.getName());
        else
            rddVars.add(input3.getName());
        if (inW != null)
            rddVars.add(_input4.getName());
    }
    // output handling, incl aggregation
    // scalar-producing ops (wtype1 / wtype4) are reduced to a single value below
    if (// map/redwsloss, map/redwcemm
    qop.wtype1 != null || qop.wtype4 != null) {
        // full aggregate and cast to scalar
        MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
        DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
        sec.setVariable(output.getName(), ret);
    } else // map/redwsigmoid, map/redwdivmm, map/redwumm
    {
        // aggregation if required (map/redwdivmm)
        if (qop.wtype3 != null && !qop.wtype3.isBasic())
            out = RDDAggregateUtils.sumByKeyStable(out, false);
        // put output RDD handle into symbol table
        sec.setRDDHandleForVariable(output.getName(), out);
        // maintain lineage information for output rdd
        for (String rddVar : rddVars) sec.addLineageRDD(output.getName(), rddVar);
        for (String bcVar : bcVars) sec.addLineageBroadcast(output.getName(), bcVar);
        // update matrix characteristics
        updateOutputMatrixCharacteristics(sec, qop);
    }
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) FilterNonEmptyBlocksFunction(org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) ArrayList(java.util.ArrayList) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) ReplicateBlockFunction(org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction)

Example 2 with QuaternaryOperator

use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project incubator-systemml by apache.

Method processInstruction of class QuaternaryCPInstruction.

/**
 * Executes a weighted quaternary operation (wsloss, wsigmoid, wdivmm, wcemm or wumm) in CP.
 *
 * <p>The first three inputs are pinned as matrices. If the operator has a fourth input, it is
 * either a scalar (wrapped into a 1x1 block) or a matrix that is pinned as well. Every input
 * pinned here is released exactly once, also when the core operation fails. If
 * {@code wtype1}/{@code wtype4} is set (wsloss/wcemm), a {@code DoubleObject} holding cell (0,0)
 * of the result is bound to the output variable; otherwise the result matrix is bound.
 *
 * @param ec execution context providing the inputs and receiving the output
 */
@Override
public void processInstruction(ExecutionContext ec) {
    QuaternaryOperator qop = (QuaternaryOperator) _optr;
    MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
    MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
    MatrixBlock matBlock3 = ec.getMatrixInput(input3.getName(), getExtendedOpcode());
    MatrixBlock matBlock4 = null;
    // remember whether input4 was pinned via getMatrixInput so that the release mirrors the
    // acquisition exactly (no leak, no release of something that was never acquired)
    boolean input4Pinned = false;
    if (qop.hasFourInputs()) {
        if (input4.getDataType() == DataType.SCALAR) {
            // scalar fourth operand is passed to the kernel as a 1x1 block
            matBlock4 = new MatrixBlock(1, 1, false);
            final double eps = ec.getScalarInput(input4.getName(), input4.getValueType(), input4.isLiteral()).getDoubleValue();
            matBlock4.quickSetValue(0, 0, eps);
        } else {
            matBlock4 = ec.getMatrixInput(input4.getName(), getExtendedOpcode());
            input4Pinned = true;
        }
    }
    MatrixBlock out;
    try {
        // core execute
        out = matBlock1.quaternaryOperations(qop, matBlock2, matBlock3, matBlock4, new MatrixBlock(), _numThreads);
    } finally {
        // release all pinned inputs, also if the core operation throws
        ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
        ec.releaseMatrixInput(input2.getName(), getExtendedOpcode());
        ec.releaseMatrixInput(input3.getName(), getExtendedOpcode());
        if (input4Pinned) {
            ec.releaseMatrixInput(input4.getName(), getExtendedOpcode());
        }
    }
    if (qop.wtype1 != null || qop.wtype4 != null) {
        // wsloss/wcemm: scalar result
        ec.setVariable(output.getName(), new DoubleObject(out.quickGetValue(0, 0)));
    } else {
        // wsigmoid / wdivmm / wumm: matrix result
        ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
    }
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock)

Example 3 with QuaternaryOperator

use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project incubator-systemml by apache.

Method computeMatrixCharacteristics of class QuaternaryInstruction.

/**
 * Derives the output dimensions of this quaternary instruction and stores them in {@code dimOut}.
 *
 * <p>Block sizes are always taken from the main input {@code mc1}. If none of the weighted
 * operator types is set on the operator, {@code dimOut} is left untouched.
 *
 * @param mc1    characteristics of the main input
 * @param mc2    characteristics of the U factor (for redwdivmm its nnz carries the original rank)
 * @param mc3    characteristics of the V factor (for redwdivmm its nnz carries the original rank)
 * @param dimOut receives the computed output characteristics
 */
public void computeMatrixCharacteristics(MatrixCharacteristics mc1, MatrixCharacteristics mc2, MatrixCharacteristics mc3, MatrixCharacteristics dimOut) {
    final QuaternaryOperator op = (QuaternaryOperator) optr;

    // wsloss / wcemm: always a 1x1 scalar result, independent of the chain type
    if (op.wtype1 != null || op.wtype4 != null) {
        dimOut.set(1, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        return;
    }

    // wsigmoid / wumm: result is shaped like the main input
    if (op.wtype2 != null || op.wtype5 != null) {
        dimOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        return;
    }

    if (op.wtype3 == null) {
        return;
    }

    // wdivmm: mc2/mc3 cannot be consumed directly for redwdivmm because the rep instruction
    // changed the relevant dimensions; as a workaround the original dims are passed via nnz
    final boolean bothFactorsCached = _cacheU && _cacheV;
    final MatrixCharacteristics factor = op.wtype3.isLeft() ? mc3 : mc2;
    final long rank = bothFactorsCached ? factor.getCols() : factor.getNonZeros();
    final MatrixCharacteristics wdivmmDims = op.wtype3.computeOutputCharacteristics(mc1.getRows(), mc1.getCols(), rank);
    dimOut.set(wdivmmDims.getRows(), wdivmmDims.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics)

Example 4 with QuaternaryOperator

use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project incubator-systemml by apache.

Method processInstruction of class QuaternaryInstruction.

/**
 * Applies the weighted quaternary operation to every cached block of the main input (MR runtime).
 *
 * <p>For each block X_ij of {@code _input1}, the weight block W_ij is taken from {@code _input4}
 * when one is cached. Otherwise, if the operator has four inputs, the weight is a literal parsed
 * from the instruction string. U_i and V_j come from the shuffled inputs or from the distributed
 * cache ({@code _cacheU}/{@code _cacheV}). The resulting block is written to {@code output}.
 *
 * @param valueClass      value class used when reserving an output slot in the cache
 * @param cachedValues    map of cached input/output blocks, keyed by operand index
 * @param tempValue       scratch value reused as output when {@code output == _input1}
 * @param zeroInput       unused here
 * @param blockRowFactor  unused here
 * @param blockColFactor  unused here
 */
@Override
public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) {
    QuaternaryOperator qop = (QuaternaryOperator) optr;
    ArrayList<IndexedMatrixValue> blkList = cachedValues.get(_input1);
    // literal fourth operand; parsed lazily, at most once per call, instead of
    // re-tokenizing the instruction string for every block
    Double literalW = null;
    if (blkList != null)
        for (IndexedMatrixValue imv : blkList) {
            // Step 1: prepare inputs and output
            if (imv == null)
                continue;
            MatrixIndexes inIx = imv.getIndexes();
            MatrixBlock inVal = (MatrixBlock) imv.getValue();
            // allocate space for the output value
            IndexedMatrixValue iout = null;
            if (output == _input1)
                iout = tempValue;
            else
                iout = cachedValues.holdPlace(output, valueClass);
            MatrixIndexes outIx = iout.getIndexes();
            MatrixValue outVal = iout.getValue();
            // Step 2: get remaining inputs: Wij, Ui, Vj
            MatrixBlock Xij = inVal;
            // get Wij if existing (null for WeightsType.NONE or WSigmoid any type)
            IndexedMatrixValue iWij = (_input4 != -1) ? cachedValues.getFirst(_input4) : null;
            MatrixValue Wij = (iWij != null) ? iWij.getValue() : null;
            if (null == Wij && qop.hasFourInputs()) {
                // no cached weight block: the fourth operand is a literal in instruction part 4
                if (literalW == null)
                    literalW = Double.parseDouble(InstructionUtils.getInstructionParts(instString)[4]);
                MatrixBlock mb = new MatrixBlock(1, 1, false);
                mb.quickSetValue(0, 0, literalW);
                Wij = mb;
            }
            // get Ui and Vj (t(V)), either from the shuffled inputs or through the distributed cache
            MatrixValue Ui = (!_cacheU)
                ? cachedValues.getFirst(_input2).getValue()
                : MRBaseForCommonInstructions.dcValues.get(_input2).getDataBlock((int) inIx.getRowIndex(), 1).getValue();
            MatrixValue Vj = (!_cacheV)
                ? cachedValues.getFirst(_input3).getValue()
                : MRBaseForCommonInstructions.dcValues.get(_input3).getDataBlock((int) inIx.getColumnIndex(), 1).getValue();
            // handle special input case: V through shuffle -> t(V)
            if (Ui.getNumColumns() != Vj.getNumColumns()) {
                Vj = LibMatrixReorg.reorg((MatrixBlock) Vj, new MatrixBlock(Vj.getNumColumns(), Vj.getNumRows(), Vj.isInSparseFormat()), new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
            }
            // Step 3: process instruction
            Xij.quaternaryOperations(qop, (MatrixBlock) Ui, (MatrixBlock) Vj, (MatrixBlock) Wij, (MatrixBlock) outVal);
            if (qop.wtype1 != null || qop.wtype4 != null)
                // wsloss/wcemm: scalar partial result, single output block
                outIx.setIndexes(1, 1);
            else if (qop.wtype2 != null || qop.wtype5 != null || (qop.wtype3 != null && qop.wtype3.isBasic()))
                // wsigmoid/wumm/wdivmm-basic: same block position as the input
                outIx.setIndexes(inIx);
            else {
                // wdivmm (non-basic): output block indexed by the factor side
                boolean left = qop.wtype3.isLeft();
                outIx.setIndexes(left ? inIx.getColumnIndex() : inIx.getRowIndex(), 1);
            }
            // put the output value in the cache
            if (iout == tempValue)
                cachedValues.add(output, iout);
        }
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) IndexedMatrixValue(org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue) MatrixValue(org.apache.sysml.runtime.matrix.data.MatrixValue) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) ReorgOperator(org.apache.sysml.runtime.matrix.operators.ReorgOperator) IndexedMatrixValue(org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue)

Example 5 with QuaternaryOperator

use of org.apache.sysml.runtime.matrix.operators.QuaternaryOperator in project systemml by apache.

Method processInstruction of class QuaternaryCPInstruction.

/**
 * Executes a weighted quaternary operation (wsloss, wsigmoid, wdivmm, wcemm or wumm) in CP.
 *
 * <p>The first three inputs are pinned as matrices. If the operator has a fourth input, it is
 * either a scalar (wrapped into a 1x1 block) or a matrix that is pinned as well. Every input
 * pinned here is released exactly once, also when the core operation fails. If
 * {@code wtype1}/{@code wtype4} is set (wsloss/wcemm), a {@code DoubleObject} holding cell (0,0)
 * of the result is bound to the output variable; otherwise the result matrix is bound.
 *
 * @param ec execution context providing the inputs and receiving the output
 */
@Override
public void processInstruction(ExecutionContext ec) {
    QuaternaryOperator qop = (QuaternaryOperator) _optr;
    MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
    MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
    MatrixBlock matBlock3 = ec.getMatrixInput(input3.getName(), getExtendedOpcode());
    MatrixBlock matBlock4 = null;
    // remember whether input4 was pinned via getMatrixInput so that the release mirrors the
    // acquisition exactly (no leak, no release of something that was never acquired)
    boolean input4Pinned = false;
    if (qop.hasFourInputs()) {
        if (input4.getDataType() == DataType.SCALAR) {
            // scalar fourth operand is passed to the kernel as a 1x1 block
            matBlock4 = new MatrixBlock(1, 1, false);
            final double eps = ec.getScalarInput(input4.getName(), input4.getValueType(), input4.isLiteral()).getDoubleValue();
            matBlock4.quickSetValue(0, 0, eps);
        } else {
            matBlock4 = ec.getMatrixInput(input4.getName(), getExtendedOpcode());
            input4Pinned = true;
        }
    }
    MatrixBlock out;
    try {
        // core execute
        out = matBlock1.quaternaryOperations(qop, matBlock2, matBlock3, matBlock4, new MatrixBlock(), _numThreads);
    } finally {
        // release all pinned inputs, also if the core operation throws
        ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
        ec.releaseMatrixInput(input2.getName(), getExtendedOpcode());
        ec.releaseMatrixInput(input3.getName(), getExtendedOpcode());
        if (input4Pinned) {
            ec.releaseMatrixInput(input4.getName(), getExtendedOpcode());
        }
    }
    if (qop.wtype1 != null || qop.wtype4 != null) {
        // wsloss/wcemm: scalar result
        ec.setVariable(output.getName(), new DoubleObject(out.quickGetValue(0, 0)));
    } else {
        // wsigmoid / wdivmm / wumm: matrix result
        ec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
    }
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock)

Aggregations

QuaternaryOperator (org.apache.sysml.runtime.matrix.operators.QuaternaryOperator)14 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)6 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)6 WDivMMType (org.apache.sysml.lops.WeightedDivMM.WDivMMType)4 WeightsType (org.apache.sysml.lops.WeightedSquaredLoss.WeightsType)4 WUMMType (org.apache.sysml.lops.WeightedUnaryMM.WUMMType)4 MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics)4 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)4 ArrayList (java.util.ArrayList)2 WCeMMType (org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType)2 SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)2 CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand)2 DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject)2 FilterNonEmptyBlocksFunction (org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction)2 ReplicateBlockFunction (org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction)2 MatrixValue (org.apache.sysml.runtime.matrix.data.MatrixValue)2 IndexedMatrixValue (org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue)2 ReorgOperator (org.apache.sysml.runtime.matrix.operators.ReorgOperator)2