Search in sources :

Example 1 with Ctable

use of org.apache.sysml.lops.Ctable in project systemml by apache.

the class CtableSPInstruction method processInstruction.

/**
 * Executes the ctable (contingency table) instruction on Spark.
 * <p>
 * The ctable variant is derived from the data types of the three inputs, or forced to
 * {@code CTABLE_EXPAND_SCALAR_WEIGHT} when {@code _isExpand} is set. Every variant ends up
 * with an RDD of (index, value) cells: the expand variant produces them directly, all other
 * variants build per-block ctable maps whose cells are extracted and summed by key.
 * If both output dimensions are given, the cells are passed through {@code FilterCells};
 * if both are -1, the dimensions are computed from the (cached) cell RDD.
 * The result is registered as the output RDD with block sizes set to -1 (binary cell),
 * and lineage is recorded for every input that was read as an RDD.
 *
 * @param ec execution context, cast to {@code SparkExecutionContext}
 * @throws DMLRuntimeException if the ctable operation is not supported, if the switch did
 *         not produce exactly one intermediate, or if exactly one output dimension is -1
 */
@Override
public void processInstruction(ExecutionContext ec) {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    // get input rdd handle
    JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = null;
    JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = null;
    // scalar inputs keep the placeholder -1 unless the chosen variant reads them
    double scalar_input2 = -1, scalar_input3 = -1;
    Ctable.OperationTypes ctableOp = Ctable.findCtableOperationByInputDataTypes(input1.getDataType(), input2.getDataType(), input3.getDataType());
    ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;
    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
    // First get the block sizes (the output block sizes are set to -1 below
    // to allow for binary cell reblock)
    int brlen = mc1.getRowsPerBlock();
    int bclen = mc1.getColsPerBlock();
    JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs = null;
    JavaPairRDD<MatrixIndexes, CTableMap> ctables = null;
    JavaPairRDD<MatrixIndexes, Double> bincellsNoFilter = null;
    // lineage flags: set whenever input2/input3 are consumed as RDDs
    boolean setLineage2 = false;
    boolean setLineage3 = false;
    switch(ctableOp) {
        case CTABLE_TRANSFORM: // (VECTOR)
            // F=ctable(A,B,W)
            in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
            in3 = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
            setLineage2 = true;
            setLineage3 = true;
            inputMBs = in1.cogroup(in2).cogroup(in3).mapToPair(new MapThreeMBIterableIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;
        case CTABLE_EXPAND_SCALAR_WEIGHT: // (VECTOR)
            // F = ctable(seq,A) or F = ctable(seq,B,1)
            scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            if (scalar_input3 == 1) {
                // weight 1: cells are produced directly from the second input
                in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
                setLineage2 = true;
                bincellsNoFilter = in2.flatMapToPair(new ExpandScalarCtableOperation(brlen));
                break;
            }
            // intentional fall through: for weights other than 1 the expand
            // case is handled like the general scalar-weight case below
        case CTABLE_TRANSFORM_SCALAR_WEIGHT: // (VECTOR/MATRIX)
            // F = ctable(A,B) or F = ctable(A,B,1)
            in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
            setLineage2 = true;
            scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            inputMBs = in1.cogroup(in2).mapToPair(new MapTwoMBIterableIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;
        case CTABLE_TRANSFORM_HISTOGRAM: // (VECTOR)
            // F=ctable(A,1) or F = ctable(A,1,1)
            scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            inputMBs = in1.mapToPair(new MapMBIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;
        case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: // (VECTOR)
            // F=ctable(A,1,W)
            in3 = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
            setLineage3 = true;
            scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            inputMBs = in1.cogroup(in3).mapToPair(new MapTwoMBIterableIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;
        default:
            throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
    }
    // sanity check: the switch must have produced exactly one of the two
    // intermediates (direct cells or ctable maps)
    if ((bincellsNoFilter == null) == (ctables == null)) {
        throw new DMLRuntimeException("Incorrect ctable operation");
    }
    // Now perform aggregation on ctables to get binaryCells
    if (ctables != null) {
        bincellsNoFilter = ctables.values().flatMapToPair(new ExtractBinaryCellsFromCTable());
        bincellsNoFilter = RDDAggregateUtils.sumCellsByKeyStable(bincellsNoFilter);
    }
    // handle known/unknown dimensions (-1 denotes an unknown dimension)
    long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (sec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue());
    long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (sec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue());
    MatrixCharacteristics mcBinaryCells = null;
    boolean findDimensions = (outputDim1 == -1 && outputDim2 == -1);
    if (!findDimensions) {
        // not both unknown, so any remaining -1 means exactly one dimension is missing
        if (outputDim1 == -1 || outputDim2 == -1)
            throw new DMLRuntimeException("Incorrect output dimensions passed to CtableSPInstruction:" + outputDim1 + " " + outputDim2);
        mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen);
        // filtering according to given dimensions
        bincellsNoFilter = bincellsNoFilter.filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols()));
    }
    // convert double values to matrix cell
    JavaPairRDD<MatrixIndexes, MatrixCell> binaryCells = bincellsNoFilter.mapToPair(new ConvertToBinaryCell());
    // find dimensions if necessary (w/ cache for reblock)
    if (findDimensions) {
        binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells);
        mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells);
    }
    // store output rdd handle
    sec.setRDDHandleForVariable(output.getName(), binaryCells);
    mcOut.set(mcBinaryCells);
    // Since we are outputting binary cells, we set block sizes = -1
    mcOut.setRowsPerBlock(-1);
    mcOut.setColsPerBlock(-1);
    sec.addLineageRDD(output.getName(), input1.getName());
    if (setLineage2)
        sec.addLineageRDD(output.getName(), input2.getName());
    if (setLineage3)
        sec.addLineageRDD(output.getName(), input3.getName());
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) ArrayList(java.util.ArrayList) Ctable(org.apache.sysml.lops.Ctable) MatrixCell(org.apache.sysml.runtime.matrix.data.MatrixCell) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) CTableMap(org.apache.sysml.runtime.matrix.data.CTableMap)

Example 2 with Ctable

use of org.apache.sysml.lops.Ctable in project systemml by apache.

the class TernaryOp method constructLopsCtable.

/**
 * Method to construct LOPs when op = CTABLE.
 * <p>
 * For CP and SPARK a single {@code Ctable} lop is created over the input lops (Spark
 * output is marked as requiring a reblock). For MR the matrix inputs are wrapped in
 * sort-{@code Group} lops, and the ctable result is followed by a group + sum aggregate
 * unless the inputs are disjoint or the expand variant is used.
 *
 * @throws HopsException if this hop is not a CTABLE operation or if the derived ctable
 *         operation type is not supported in the MR path
 */
private void constructLopsCtable() {
    if (_op != OpOp3.CTABLE)
        throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.CTABLE);
    /*
     * We must handle three different cases:
     *   case1 : all three inputs are vectors (e.g., F=ctable(A,B,W))
     *   case2 : two vectors and one scalar (e.g., F=ctable(A,B))
     *   case3 : one vector and two scalars (e.g., F=ctable(A))
     */
    // identify the particular case from the data types of the first three inputs
    DataType dt1 = getInput().get(0).getDataType();
    DataType dt2 = getInput().get(1).getDataType();
    DataType dt3 = getInput().get(2).getDataType();
    Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
    // Compute lops for all inputs
    Lop[] inputLops = new Lop[getInput().size()];
    for (int i = 0; i < getInput().size(); i++) {
        inputLops[i] = getInput().get(i).constructLops();
    }
    ExecType et = optFindExecType();
    // reset reblock requirement (see MR ctable / construct lops)
    setRequiresReblock(false);
    if (et == ExecType.CP || et == ExecType.SPARK) {
        // for CP we support only ctable expand left
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        boolean ignoreZeros = false;
        if (isMatrixIgnoreZeroRewriteApplicable()) {
            // table - rmempty - rshape: bypass the first two inputs and use the
            // first input of their target hops directly, ignoring zeros instead
            ignoreZeros = true;
            inputLops[0] = ((ParameterizedBuiltinOp) getInput().get(0)).getTargetHop().getInput().get(0).constructLops();
            inputLops[1] = ((ParameterizedBuiltinOp) getInput().get(1)).getTargetHop().getInput().get(0).constructLops();
        }
        Ctable ternary = new Ctable(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et);
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        setLineNumbers(ternary);
        // force blocked output in CP (see below), otherwise binarycell
        if (et == ExecType.SPARK) {
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, -1, -1, -1);
            setRequiresReblock(true);
        } else
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        // ternary opt, w/o reblock in CP
        setLops(ternary);
    } else // MR
    {
        // for MR we support both ctable expand left and right
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable() ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        Group group1 = null, group2 = null, group3 = null, group4 = null;
        group1 = new Group(inputLops[0], Group.OperationTypes.Sort, getDataType(), getValueType());
        group1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(group1);
        Ctable ternary = null;
        // create "group" lops for MATRIX inputs; if more than three input lops
        // exist, inputLops[3] and inputLops[4] are appended to the ctable inputs
        switch(ternaryOp) {
            case CTABLE_TRANSFORM:
                // F = ctable(A,B,W)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                group3 = new Group(inputLops[2], Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    // output dimensions are given
                    ternary = new Ctable(new Lop[] { group1, group2, group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_SCALAR_WEIGHT:
                // F = ctable(A,B) or F = ctable(A,B,1)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2] }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2], inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_EXPAND_SCALAR_WEIGHT:
                // F=ctable(seq(1,N),A) or F = ctable(seq,A,1)
                // left 1, right 0 (index of input data)
                int left = isSequenceRewriteApplicable(true) ? 1 : 0;
                Group group = new Group(getInput().get(left).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                // keep line numbers consistent with all other lops created in this method
                setLineNumbers(group);
                // ctable inputs: matrix (group), weight, left indicator [, dim inputs]
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group, getInput().get(2).constructLops(), new LiteralOp(left).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group, getInput().get(2).constructLops(), new LiteralOp(left).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_HISTOGRAM:
                // F=ctable(A,1) or F = ctable(A,1,1)
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
                // F=ctable(A,1,W)
                group3 = new Group(getInput().get(2).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            default:
                // report the ctable operation type we switched on (_op is always CTABLE here)
                throw new HopsException("Invalid ternary operator type: " + ternaryOp);
        }
        // output dimensions are not known at compilation time
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
        setLineNumbers(ternary);
        Lop lctable = ternary;
        // no need for aggregation if (1) input indexed disjoint or
        // (2) one side is sequence w/ 1 increment (expand)
        if (!(_disjointInputs || ternaryOp == Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)) {
            group4 = new Group(ternary, Group.OperationTypes.Sort, getDataType(), getValueType());
            group4.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(group4);
            Aggregate agg1 = new Aggregate(group4, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
            agg1.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(agg1);
            // Kahan sum is used for aggregation but inputs do not have
            // correction values
            agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
            lctable = agg1;
        }
        setLops(lctable);
        // to introduce reblock lop since table itself outputs in blocked format if dims known.
        if (!dimsKnown() && !_dimInputsPresent) {
            setRequiresReblock(true);
        }
    }
}
Also used : Group(org.apache.sysml.lops.Group) Lop(org.apache.sysml.lops.Lop) Ctable(org.apache.sysml.lops.Ctable) DataType(org.apache.sysml.parser.Expression.DataType) ExecType(org.apache.sysml.lops.LopProperties.ExecType) Aggregate(org.apache.sysml.lops.Aggregate)

Example 3 with Ctable

use of org.apache.sysml.lops.Ctable in project systemml by apache.

the class CtableCPInstruction method processInstruction.

/**
 * Executes the ctable (contingency table) instruction in CP (single node).
 * <p>
 * The ctable variant is determined via {@code findCtableOperation()}, or forced to
 * {@code CTABLE_EXPAND_SCALAR_WEIGHT} when {@code _isExpand} is set. Results are
 * aggregated either into a preallocated result block (dense output with known
 * dimensions, or the expand case) or into a {@code CTableMap} that is converted to a
 * matrix block afterwards. Matrix inputs are released before the output is set.
 *
 * @param ec execution context providing the matrix/scalar inputs and receiving the output
 * @throws DMLRuntimeException if the ctable operation is not supported
 */
@Override
public void processInstruction(ExecutionContext ec) {
    MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
    MatrixBlock matBlock2 = null, wtBlock = null;
    double cst1, cst2;
    CTableMap resultMap = new CTableMap(EntryType.INT);
    MatrixBlock resultBlock = null;
    Ctable.OperationTypes ctableOp = findCtableOperation();
    ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;
    // output dimensions, where -1 denotes an unknown dimension
    long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue());
    long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue());
    boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);
    if (outputDimsKnown) {
        int inputRows = matBlock1.getNumRows();
        int inputCols = matBlock1.getNumColumns();
        // widen before multiplying: rows*cols of a large input overflows int
        boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, (long) inputRows * inputCols);
        // only preallocate the result block if it is dense; we do not aggregate on sparse
        // blocks because it would implicitly turn the O(N) algorithm into O(N log N).
        if (!sparse)
            resultBlock = new MatrixBlock((int) outputDim1, (int) outputDim2, false);
    }
    if (_isExpand) {
        // sparse result block; the column count is a placeholder (see expand case below)
        resultBlock = new MatrixBlock(matBlock1.getNumRows(), Integer.MAX_VALUE, true);
    }
    switch(ctableOp) {
        case CTABLE_TRANSFORM: // (VECTOR)
            // F=ctable(A,B,W)
            matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
            wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode());
            matBlock1.ctableOperations((SimpleOperator) _optr, matBlock2, wtBlock, resultMap, resultBlock);
            break;
        case CTABLE_TRANSFORM_SCALAR_WEIGHT: // (VECTOR/MATRIX)
            // F = ctable(A,B) or F = ctable(A,B,1)
            matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
            cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            matBlock1.ctableOperations((SimpleOperator) _optr, matBlock2, cst1, _ignoreZeros, resultMap, resultBlock);
            break;
        case CTABLE_EXPAND_SCALAR_WEIGHT: // (VECTOR)
            // F = ctable(seq,A) or F = ctable(seq,B,1)
            matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
            cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            // only resultBlock.rlen known, resultBlock.clen set in operation
            matBlock1.ctableOperations((SimpleOperator) _optr, matBlock2, cst1, resultBlock);
            break;
        case CTABLE_TRANSFORM_HISTOGRAM: // (VECTOR)
            // F=ctable(A,1) or F = ctable(A,1,1)
            cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            cst2 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            matBlock1.ctableOperations((SimpleOperator) _optr, cst1, cst2, resultMap, resultBlock);
            break;
        case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: // (VECTOR)
            // F=ctable(A,1,W)
            wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode());
            cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            matBlock1.ctableOperations((SimpleOperator) _optr, cst1, wtBlock, resultMap, resultBlock);
            break;
        default:
            throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
    }
    // release all inputs that were acquired as matrices
    if (input1.getDataType() == DataType.MATRIX)
        ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
    if (input2.getDataType() == DataType.MATRIX)
        ec.releaseMatrixInput(input2.getName(), getExtendedOpcode());
    if (input3.getDataType() == DataType.MATRIX)
        ec.releaseMatrixInput(input3.getName(), getExtendedOpcode());
    if (resultBlock == null) {
        // no result block was preallocated, i.e., the result lives in resultMap; we
        // decided for hash-aggregation just to prevent inefficiency in case of sparse outputs.
        if (outputDimsKnown)
            resultBlock = DataConverter.convertToMatrixBlock(resultMap, (int) outputDim1, (int) outputDim2);
        else
            resultBlock = DataConverter.convertToMatrixBlock(resultMap);
    } else
        resultBlock.examSparsity();
    // re-examine the dense/sparse output representation for special cases
    // such as ctable expand (guarded by released input memory)
    if (checkGuardedRepresentationChange(matBlock1, matBlock2, resultBlock)) {
        resultBlock.examSparsity();
    }
    ec.setMatrixOutput(output.getName(), resultBlock, getExtendedOpcode());
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) CTableMap(org.apache.sysml.runtime.matrix.data.CTableMap) Ctable(org.apache.sysml.lops.Ctable) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 4 with Ctable

use of org.apache.sysml.lops.Ctable in project incubator-systemml by apache.

the class TernaryOp method constructLopsCtable.

/**
 * Method to construct LOPs when op = CTABLE.
 * <p>
 * For CP and SPARK a single {@code Ctable} lop is created over the input lops (Spark
 * output is marked as requiring a reblock). For MR the matrix inputs are wrapped in
 * sort-{@code Group} lops, and the ctable result is followed by a group + sum aggregate
 * unless the inputs are disjoint or the expand variant is used.
 *
 * @throws HopsException if this hop is not a CTABLE operation or if the derived ctable
 *         operation type is not supported in the MR path
 */
private void constructLopsCtable() {
    if (_op != OpOp3.CTABLE)
        throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.CTABLE);
    /*
     * We must handle three different cases:
     *   case1 : all three inputs are vectors (e.g., F=ctable(A,B,W))
     *   case2 : two vectors and one scalar (e.g., F=ctable(A,B))
     *   case3 : one vector and two scalars (e.g., F=ctable(A))
     */
    // identify the particular case from the data types of the first three inputs
    DataType dt1 = getInput().get(0).getDataType();
    DataType dt2 = getInput().get(1).getDataType();
    DataType dt3 = getInput().get(2).getDataType();
    Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
    // Compute lops for all inputs
    Lop[] inputLops = new Lop[getInput().size()];
    for (int i = 0; i < getInput().size(); i++) {
        inputLops[i] = getInput().get(i).constructLops();
    }
    ExecType et = optFindExecType();
    // reset reblock requirement (see MR ctable / construct lops)
    setRequiresReblock(false);
    if (et == ExecType.CP || et == ExecType.SPARK) {
        // for CP we support only ctable expand left
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        boolean ignoreZeros = false;
        if (isMatrixIgnoreZeroRewriteApplicable()) {
            // table - rmempty - rshape: bypass the first two inputs and use the
            // first input of their target hops directly, ignoring zeros instead
            ignoreZeros = true;
            inputLops[0] = ((ParameterizedBuiltinOp) getInput().get(0)).getTargetHop().getInput().get(0).constructLops();
            inputLops[1] = ((ParameterizedBuiltinOp) getInput().get(1)).getTargetHop().getInput().get(0).constructLops();
        }
        Ctable ternary = new Ctable(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et);
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        setLineNumbers(ternary);
        // force blocked output in CP (see below), otherwise binarycell
        if (et == ExecType.SPARK) {
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, -1, -1, -1);
            setRequiresReblock(true);
        } else
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        // ternary opt, w/o reblock in CP
        setLops(ternary);
    } else // MR
    {
        // for MR we support both ctable expand left and right
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable() ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        Group group1 = null, group2 = null, group3 = null, group4 = null;
        group1 = new Group(inputLops[0], Group.OperationTypes.Sort, getDataType(), getValueType());
        group1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(group1);
        Ctable ternary = null;
        // create "group" lops for MATRIX inputs; if more than three input lops
        // exist, inputLops[3] and inputLops[4] are appended to the ctable inputs
        switch(ternaryOp) {
            case CTABLE_TRANSFORM:
                // F = ctable(A,B,W)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                group3 = new Group(inputLops[2], Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    // output dimensions are given
                    ternary = new Ctable(new Lop[] { group1, group2, group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_SCALAR_WEIGHT:
                // F = ctable(A,B) or F = ctable(A,B,1)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2] }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2], inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_EXPAND_SCALAR_WEIGHT:
                // F=ctable(seq(1,N),A) or F = ctable(seq,A,1)
                // left 1, right 0 (index of input data)
                int left = isSequenceRewriteApplicable(true) ? 1 : 0;
                Group group = new Group(getInput().get(left).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                // keep line numbers consistent with all other lops created in this method
                setLineNumbers(group);
                // ctable inputs: matrix (group), weight, left indicator [, dim inputs]
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group, getInput().get(2).constructLops(), new LiteralOp(left).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group, getInput().get(2).constructLops(), new LiteralOp(left).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_HISTOGRAM:
                // F=ctable(A,1) or F = ctable(A,1,1)
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
                // F=ctable(A,1,W)
                group3 = new Group(getInput().get(2).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            default:
                // report the ctable operation type we switched on (_op is always CTABLE here)
                throw new HopsException("Invalid ternary operator type: " + ternaryOp);
        }
        // output dimensions are not known at compilation time
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
        setLineNumbers(ternary);
        Lop lctable = ternary;
        // no need for aggregation if (1) input indexed disjoint or
        // (2) one side is sequence w/ 1 increment (expand)
        if (!(_disjointInputs || ternaryOp == Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)) {
            group4 = new Group(ternary, Group.OperationTypes.Sort, getDataType(), getValueType());
            group4.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(group4);
            Aggregate agg1 = new Aggregate(group4, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
            agg1.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(agg1);
            // Kahan sum is used for aggregation but inputs do not have
            // correction values
            agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
            lctable = agg1;
        }
        setLops(lctable);
        // to introduce reblock lop since table itself outputs in blocked format if dims known.
        if (!dimsKnown() && !_dimInputsPresent) {
            setRequiresReblock(true);
        }
    }
}
Also used : Group(org.apache.sysml.lops.Group) Lop(org.apache.sysml.lops.Lop) Ctable(org.apache.sysml.lops.Ctable) DataType(org.apache.sysml.parser.Expression.DataType) ExecType(org.apache.sysml.lops.LopProperties.ExecType) Aggregate(org.apache.sysml.lops.Aggregate)

Example 5 with Ctable

use of org.apache.sysml.lops.Ctable in project incubator-systemml by apache.

the class CtableSPInstruction method processInstruction.

/**
 * Executes the ctable instruction on Spark and registers the result as a
 * binary-cell RDD under the output variable.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Determine the ctable operation from the data types of the three inputs
 *       (overridden by {@code _isExpand}).</li>
 *   <li>Build either per-block {@code CTableMap}s (summed per cell afterwards) or,
 *       for the expand case with weight 1, the output cells directly.</li>
 *   <li>If output dimensions were supplied, filter the cells against them;
 *       otherwise cache the cells and compute the dimensions from the data.</li>
 *   <li>Store the RDD handle, update the output matrix characteristics (block
 *       sizes are set to -1 because the output is in binary-cell format), and
 *       record lineage to every matrix input that was actually read.</li>
 * </ol>
 *
 * @param ec execution context; cast to {@code SparkExecutionContext} without a check
 * @throws DMLRuntimeException if the ctable operation is not supported, if the
 *         operation produced an inconsistent intermediate state, or if exactly
 *         one of the two output dimensions is unknown (-1)
 */
@Override
public void processInstruction(ExecutionContext ec) {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    // get input rdd handle
    JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = null;
    JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = null;
    double scalar_input2 = -1, scalar_input3 = -1;

    Ctable.OperationTypes ctableOp = Ctable.findCtableOperationByInputDataTypes(input1.getDataType(), input2.getDataType(), input3.getDataType());
    ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;

    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());

    // Block sizes are taken from the first input; the output block sizes are
    // reset to -1 at the end of this method because the result is binary cell.
    int brlen = mc1.getRowsPerBlock();
    int bclen = mc1.getColsPerBlock();

    JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs = null;
    JavaPairRDD<MatrixIndexes, CTableMap> ctables = null;
    JavaPairRDD<MatrixIndexes, Double> bincellsNoFilter = null;
    boolean setLineage2 = false;
    boolean setLineage3 = false;

    switch (ctableOp) {
        case CTABLE_TRANSFORM: // (VECTOR)
            // F=ctable(A,B,W)
            in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
            in3 = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
            setLineage2 = true;
            setLineage3 = true;
            inputMBs = in1.cogroup(in2).cogroup(in3).mapToPair(new MapThreeMBIterableIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;

        case CTABLE_EXPAND_SCALAR_WEIGHT: // (VECTOR)
            // F = ctable(seq,A) or F = ctable(seq,B,1)
            scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            if (scalar_input3 == 1) {
                // weight of exactly 1: emit output cells directly, no CTableMap needed
                in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
                setLineage2 = true;
                bincellsNoFilter = in2.flatMapToPair(new ExpandScalarCtableOperation(brlen));
                break;
            }
            // intentional fall-through: any other weight is handled by the
            // scalar-weight path below

        case CTABLE_TRANSFORM_SCALAR_WEIGHT: // (VECTOR/MATRIX)
            // F = ctable(A,B) or F = ctable(A,B,1)
            in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
            setLineage2 = true;
            scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            inputMBs = in1.cogroup(in2).mapToPair(new MapTwoMBIterableIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;

        case CTABLE_TRANSFORM_HISTOGRAM: // (VECTOR)
            // F=ctable(A,1) or F = ctable(A,1,1)
            scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            inputMBs = in1.mapToPair(new MapMBIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;

        case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: // (VECTOR)
            // F=ctable(A,1,W)
            in3 = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
            setLineage3 = true;
            scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            inputMBs = in1.cogroup(in3).mapToPair(new MapTwoMBIterableIntoAL());
            ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, scalar_input3, this.instString, (SimpleOperator) _optr, _ignoreZeros));
            break;

        default:
            throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
    }

    // Exactly one of {ctables, bincellsNoFilter} must have been produced by the
    // switch above; ctables still need to be flattened and summed per cell.
    if (bincellsNoFilter == null && ctables != null) {
        bincellsNoFilter = ctables.values().flatMapToPair(new ExtractBinaryCellsFromCTable());
        bincellsNoFilter = RDDAggregateUtils.sumCellsByKeyStable(bincellsNoFilter);
    } else if (bincellsNoFilter == null || ctables != null) {
        // neither or both were set
        throw new DMLRuntimeException("Incorrect ctable operation");
    }

    // handle known/unknown dimensions
    long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (sec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue());
    long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (sec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue());
    MatrixCharacteristics mcBinaryCells = null;
    boolean findDimensions = (outputDim1 == -1 && outputDim2 == -1);

    if (!findDimensions) {
        // dimensions must be given either both or not at all
        if ((outputDim1 == -1 && outputDim2 != -1) || (outputDim1 != -1 && outputDim2 == -1)) {
            throw new DMLRuntimeException("Incorrect output dimensions passed to CtableSPInstruction:" + outputDim1 + " " + outputDim2);
        }
        mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen);
        // filtering according to given dimensions
        bincellsNoFilter = bincellsNoFilter.filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols()));
    }

    // convert double values to matrix cell
    JavaPairRDD<MatrixIndexes, MatrixCell> binaryCells = bincellsNoFilter.mapToPair(new ConvertToBinaryCell());

    // find dimensions if necessary (w/ cache for reblock)
    if (findDimensions) {
        binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells);
        mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells);
    }

    // store output rdd handle
    sec.setRDDHandleForVariable(output.getName(), binaryCells);
    mcOut.set(mcBinaryCells);
    // Since we are outputting binary cells, we set block sizes = -1
    mcOut.setRowsPerBlock(-1);
    mcOut.setColsPerBlock(-1);

    // lineage: input1 is always read; input2/input3 only when used as matrices
    sec.addLineageRDD(output.getName(), input1.getName());
    if (setLineage2) {
        sec.addLineageRDD(output.getName(), input2.getName());
    }
    if (setLineage3) {
        sec.addLineageRDD(output.getName(), input3.getName());
    }
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) ArrayList(java.util.ArrayList) Ctable(org.apache.sysml.lops.Ctable) MatrixCell(org.apache.sysml.runtime.matrix.data.MatrixCell) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) CTableMap(org.apache.sysml.runtime.matrix.data.CTableMap)

Aggregations

Ctable (org.apache.sysml.lops.Ctable)6 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)4 CTableMap (org.apache.sysml.runtime.matrix.data.CTableMap)4 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)4 ArrayList (java.util.ArrayList)2 Aggregate (org.apache.sysml.lops.Aggregate)2 Group (org.apache.sysml.lops.Group)2 Lop (org.apache.sysml.lops.Lop)2 ExecType (org.apache.sysml.lops.LopProperties.ExecType)2 DataType (org.apache.sysml.parser.Expression.DataType)2 SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)2 MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics)2 MatrixCell (org.apache.sysml.runtime.matrix.data.MatrixCell)2 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)2