Example 36 with DataType

use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.

The class ArithmeticBinarySPInstruction, method parseInstruction.

public static ArithmeticBinarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    String opcode = null;
    boolean isBroadcast = false;
    VectorType vtype = null;
    if (str.startsWith("SPARK" + Lop.OPERAND_DELIMITOR + "map")) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 5);
        opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        out.split(parts[3]);
        vtype = VectorType.valueOf(parts[5]);
        isBroadcast = true;
    } else {
        opcode = parseBinaryInstruction(str, in1, in2, out);
    }
    // Arithmetic operations must be performed on DOUBLE or INT
    DataType dt1 = in1.getDataType();
    DataType dt2 = in2.getDataType();
    Operator operator = (dt1 != dt2) ? InstructionUtils.parseScalarBinaryOperator(opcode, (dt1 == DataType.SCALAR)) : InstructionUtils.parseExtendedBinaryOperator(opcode);
    if (dt1 == DataType.MATRIX || dt2 == DataType.MATRIX) {
        if (dt1 == DataType.MATRIX && dt2 == DataType.MATRIX) {
            if (isBroadcast)
                return new MatrixBVectorArithmeticSPInstruction(operator, in1, in2, out, vtype, opcode, str);
            else
                return new MatrixMatrixArithmeticSPInstruction(operator, in1, in2, out, opcode, str);
        } else
            return new MatrixScalarArithmeticSPInstruction(operator, in1, in2, out, opcode, str);
    }
    return null;
}
Also used : Operator(org.apache.sysml.runtime.matrix.operators.Operator) VectorType(org.apache.sysml.lops.BinaryM.VectorType) DataType(org.apache.sysml.parser.Expression.DataType) CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand)
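
The dispatch above turns entirely on the DataType of the two parsed operands: a matrix/scalar mix yields a scalar binary operator and the MatrixScalar instruction, while two matrix inputs yield an extended binary operator and either the broadcast (map) variant or the regular matrix-matrix one. The following is a minimal, self-contained sketch of that selection pattern; the DataType enum, the Operand record, and chooseInstruction are hypothetical stand-ins, not SystemML classes.

// Hypothetical sketch of DataType-driven instruction dispatch; not the SystemML implementation.
public class BinaryDispatchSketch {

    enum DataType { MATRIX, SCALAR, UNKNOWN }

    // Stand-in for a parsed operand token of the form "name.DATATYPE", e.g. "_mVar3.MATRIX"
    record Operand(String name, DataType dataType) {
        static Operand parse(String token) {
            String[] parts = token.split("\\.");
            return new Operand(parts[0], DataType.valueOf(parts[1]));
        }
    }

    static String chooseInstruction(Operand in1, Operand in2, boolean isBroadcast) {
        DataType dt1 = in1.dataType(), dt2 = in2.dataType();
        if (dt1 == DataType.MATRIX && dt2 == DataType.MATRIX)
            return isBroadcast ? "MatrixBVector" : "MatrixMatrix";
        if (dt1 == DataType.MATRIX || dt2 == DataType.MATRIX)
            return "MatrixScalar";
        return "ScalarScalar";
    }

    public static void main(String[] args) {
        Operand a = Operand.parse("_mVar1.MATRIX");
        Operand b = Operand.parse("_Var2.SCALAR");
        System.out.println(chooseInstruction(a, b, false)); // prints "MatrixScalar"
    }
}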

Example 37 with DataType

use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.

The class BinaryOp, method inferOutputCharacteristics.

@Override
protected long[] inferOutputCharacteristics(MemoTable memo) {
    long[] ret = null;
    MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
    Hop input1 = getInput().get(0);
    Hop input2 = getInput().get(1);
    DataType dt1 = input1.getDataType();
    DataType dt2 = input2.getDataType();
    if (op == OpOp2.CBIND) {
        long ldim1 = -1, ldim2 = -1, lnnz = -1;
        if (mc[0].rowsKnown() || mc[1].rowsKnown())
            ldim1 = mc[0].rowsKnown() ? mc[0].getRows() : mc[1].getRows();
        if (mc[0].colsKnown() && mc[1].colsKnown())
            ldim2 = mc[0].getCols() + mc[1].getCols();
        if (mc[0].nnzKnown() && mc[1].nnzKnown())
            lnnz = mc[0].getNonZeros() + mc[1].getNonZeros();
        if (ldim1 >= 0 || ldim2 >= 0 || lnnz >= 0)
            return new long[] { ldim1, ldim2, lnnz };
    } else if (op == OpOp2.RBIND) {
        long ldim1 = -1, ldim2 = -1, lnnz = -1;
        if (mc[0].colsKnown() || mc[1].colsKnown())
            ldim2 = mc[0].colsKnown() ? mc[0].getCols() : mc[1].getCols();
        if (mc[0].rowsKnown() && mc[1].rowsKnown())
            ldim1 = mc[0].getRows() + mc[1].getRows();
        if (mc[0].nnzKnown() && mc[1].nnzKnown())
            lnnz = mc[0].getNonZeros() + mc[1].getNonZeros();
        if (ldim1 >= 0 || ldim2 >= 0 || lnnz >= 0)
            return new long[] { ldim1, ldim2, lnnz };
    } else if (op == OpOp2.SOLVE) {
        // Output is a (likely dense) vector whose length equals the number of columns of the first input
        if (mc[0].getCols() >= 0) {
            ret = new long[] { mc[0].getCols(), 1, mc[0].getCols() };
        }
    } else { // general case
        long ldim1, ldim2;
        double sp1 = 1.0, sp2 = 1.0;
        if (dt1 == DataType.MATRIX && dt2 == DataType.SCALAR && mc[0].dimsKnown()) {
            ldim1 = mc[0].getRows();
            ldim2 = mc[0].getCols();
            sp1 = (mc[0].getNonZeros() > 0) ? OptimizerUtils.getSparsity(ldim1, ldim2, mc[0].getNonZeros()) : 1.0;
        } else if (dt1 == DataType.SCALAR && dt2 == DataType.MATRIX) {
            ldim1 = mc[1].getRows();
            ldim2 = mc[1].getCols();
            sp2 = (mc[1].getNonZeros() > 0) ? OptimizerUtils.getSparsity(ldim1, ldim2, mc[1].getNonZeros()) : 1.0;
        } else { // MATRIX - MATRIX
            // for cols we need to be careful with regard to matrix-vector operations
            if (outer) { // OUTER VECTOR OPERATION
                ldim1 = mc[0].getRows();
                ldim2 = mc[1].getCols();
            } else { // GENERAL CASE
                ldim1 = (mc[0].rowsKnown()) ? mc[0].getRows() : (mc[1].getRows() > 1) ? mc[1].getRows() : -1;
                ldim2 = (mc[0].colsKnown()) ? mc[0].getCols() : (mc[1].getCols() > 1) ? mc[1].getCols() : -1;
            }
            sp1 = (mc[0].getNonZeros() > 0) ? OptimizerUtils.getSparsity(ldim1, ldim2, mc[0].getNonZeros()) : 1.0;
            sp2 = (mc[1].getNonZeros() > 0) ? OptimizerUtils.getSparsity(ldim1, ldim2, mc[1].getNonZeros()) : 1.0;
        }
        if (ldim1 >= 0 && ldim2 >= 0) {
            if (OptimizerUtils.isBinaryOpConditionalSparseSafe(op) && input2 instanceof LiteralOp) {
                long lnnz = (long) (ldim1 * ldim2 * OptimizerUtils.getBinaryOpSparsityConditionalSparseSafe(sp1, op, (LiteralOp) input2));
                ret = new long[] { ldim1, ldim2, lnnz };
            } else {
                // sparsity estimates are conservative in terms of the worst-case behavior; however,
                // for outer vector operations the average case is equivalent to the worst case.
                long lnnz = (long) (ldim1 * ldim2 * OptimizerUtils.getBinaryOpSparsity(sp1, sp2, op, !outer));
                ret = new long[] { ldim1, ldim2, lnnz };
            }
        }
    }
    return ret;
}
Also used : DataType(org.apache.sysml.parser.Expression.DataType) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics)
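
For CBIND the rule above is: the row count can be taken from whichever input knows it, the column counts must both be known and are summed, and the same holds for the non-zero counts (RBIND is symmetric with rows and columns swapped). A condensed sketch of that inference is shown below, assuming a hypothetical Stats record in which -1 marks an unknown statistic.

// Hypothetical sketch of the cbind size-inference rule; -1 marks an unknown statistic.
public class BindStatsSketch {

    record Stats(long rows, long cols, long nnz) {}

    static Stats inferCbind(Stats a, Stats b) {
        long rows = (a.rows() >= 0) ? a.rows() : b.rows();                    // either input fixes the row count
        long cols = (a.cols() >= 0 && b.cols() >= 0) ? a.cols() + b.cols() : -1; // both must be known
        long nnz  = (a.nnz()  >= 0 && b.nnz()  >= 0) ? a.nnz()  + b.nnz()  : -1; // both must be known
        return new Stats(rows, cols, nnz);
    }

    public static void main(String[] args) {
        Stats left  = new Stats(1000, 10, 5000);
        Stats right = new Stats(-1,   3,  1200); // rows unknown, inherited from the left input
        System.out.println(inferCbind(left, right)); // Stats[rows=1000, cols=13, nnz=6200]
    }
}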

Example 38 with DataType

use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.

The class BinaryOp, method constructLopsBinaryDefault.

private void constructLopsBinaryDefault() {
    /* Default behavior for BinaryOp */
    // it depends on input data types
    DataType dt1 = getInput().get(0).getDataType();
    DataType dt2 = getInput().get(1).getDataType();
    if (dt1 == dt2 && dt1 == DataType.SCALAR) {
        // Both operands scalar
        BinaryScalar binScalar1 = new BinaryScalar(getInput().get(0).constructLops(), getInput().get(1).constructLops(), HopsOpOp2LopsBS.get(op), getDataType(), getValueType());
        binScalar1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
        setLineNumbers(binScalar1);
        setLops(binScalar1);
    } else if ((dt1 == DataType.MATRIX && dt2 == DataType.SCALAR) || (dt1 == DataType.SCALAR && dt2 == DataType.MATRIX)) {
        // One operand is Matrix and the other is scalar
        ExecType et = optFindExecType();
        // select specific operator implementations
        Unary.OperationTypes ot = null;
        Hop right = getInput().get(1);
        if (op == OpOp2.POW && right instanceof LiteralOp && ((LiteralOp) right).getDoubleValue() == 2.0)
            ot = Unary.OperationTypes.POW2;
        else if (op == OpOp2.MULT && right instanceof LiteralOp && ((LiteralOp) right).getDoubleValue() == 2.0)
            ot = Unary.OperationTypes.MULTIPLY2;
        else
            // general case
            ot = HopsOpOp2LopsU.get(op);
        Unary unary1 = new Unary(getInput().get(0).constructLops(), getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);
        setOutputDimensions(unary1);
        setLineNumbers(unary1);
        setLops(unary1);
    } else {
        // Both operands are matrices
        ExecType et = optFindExecType();
        boolean isGPUSoftmax = et == ExecType.GPU && op == Hop.OpOp2.DIV
            && getInput().get(0) instanceof UnaryOp && getInput().get(1) instanceof AggUnaryOp
            && ((UnaryOp) getInput().get(0)).getOp() == OpOp1.EXP
            && ((AggUnaryOp) getInput().get(1)).getOp() == AggOp.SUM
            && ((AggUnaryOp) getInput().get(1)).getDirection() == Direction.Row
            && getInput().get(0) == getInput().get(1).getInput().get(0);
        if (isGPUSoftmax) {
            UnaryCP softmax = new UnaryCP(getInput().get(0).getInput().get(0).constructLops(), UnaryCP.OperationTypes.SOFTMAX, getDataType(), getValueType(), et);
            setOutputDimensions(softmax);
            setLineNumbers(softmax);
            setLops(softmax);
        } else if (et == ExecType.CP || et == ExecType.GPU) {
            Lop binary = null;
            boolean isLeftXGt = (getInput().get(0) instanceof BinaryOp) && ((BinaryOp) getInput().get(0)).getOp() == OpOp2.GREATER;
            Hop potentialZero = isLeftXGt ? ((BinaryOp) getInput().get(0)).getInput().get(1) : null;
            boolean isLeftXGt0 = isLeftXGt && potentialZero != null && potentialZero instanceof LiteralOp && ((LiteralOp) potentialZero).getDoubleValue() == 0;
            if (op == OpOp2.MULT && isLeftXGt0 && !getInput().get(0).isVector() && !getInput().get(1).isVector() && getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) {
                binary = new ConvolutionTransform(getInput().get(0).getInput().get(0).constructLops(), getInput().get(1).constructLops(), ConvolutionTransform.OperationTypes.RELU_BACKWARD, getDataType(), getValueType(), et, -1);
            } else
                binary = new Binary(getInput().get(0).constructLops(), getInput().get(1).constructLops(), HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
            setOutputDimensions(binary);
            setLineNumbers(binary);
            setLops(binary);
        } else if (et == ExecType.SPARK) {
            Hop left = getInput().get(0);
            Hop right = getInput().get(1);
            MMBinaryMethod mbin = optFindMMBinaryMethodSpark(left, right);
            Lop binary = null;
            if (mbin == MMBinaryMethod.MR_BINARY_UAGG_CHAIN) {
                AggUnaryOp uRight = (AggUnaryOp) right;
                binary = new BinaryUAggChain(left.constructLops(), HopsOpOp2LopsB.get(op), HopsAgg2Lops.get(uRight.getOp()), HopsDirection2Lops.get(uRight.getDirection()), getDataType(), getValueType(), et);
            } else if (mbin == MMBinaryMethod.MR_BINARY_M) {
                boolean partitioned = false;
                boolean isColVector = (right.getDim2() == 1 && left.getDim1() == right.getDim1());
                binary = new BinaryM(left.constructLops(), right.constructLops(), HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et, partitioned, isColVector);
            } else {
                binary = new Binary(left.constructLops(), right.constructLops(), HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
            }
            setOutputDimensions(binary);
            setLineNumbers(binary);
            setLops(binary);
        } else { // MR
            Hop left = getInput().get(0);
            Hop right = getInput().get(1);
            MMBinaryMethod mbin = optFindMMBinaryMethod(left, right);
            if (mbin == MMBinaryMethod.MR_BINARY_M) {
                boolean needPart = requiresPartitioning(right);
                Lop dcInput = right.constructLops();
                if (needPart) {
                    // right side in distributed cache
                    ExecType etPart = (OptimizerUtils.estimateSizeExactSparsity(right.getDim1(), right.getDim2(),
                        OptimizerUtils.getSparsity(right.getDim1(), right.getDim2(), right.getNnz()))
                        < OptimizerUtils.getLocalMemBudget()) ? ExecType.CP : ExecType.MR; // operator selection
                    dcInput = new DataPartition(dcInput, DataType.MATRIX, ValueType.DOUBLE, etPart, (right.getDim2() == 1) ? PDataPartitionFormat.ROW_BLOCK_WISE_N : PDataPartitionFormat.COLUMN_BLOCK_WISE_N);
                    dcInput.getOutputParameters().setDimensions(right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), right.getNnz());
                    dcInput.setAllPositions(right.getFilename(), right.getBeginLine(), right.getBeginColumn(), right.getEndLine(), right.getEndColumn());
                }
                BinaryM binary = new BinaryM(left.constructLops(), dcInput, HopsOpOp2LopsB.get(op), getDataType(), getValueType(), ExecType.MR, needPart, (right.getDim2() == 1 && left.getDim1() == right.getDim1()));
                setOutputDimensions(binary);
                setLineNumbers(binary);
                setLops(binary);
            } else if (mbin == MMBinaryMethod.MR_BINARY_UAGG_CHAIN) {
                AggUnaryOp uRight = (AggUnaryOp) right;
                BinaryUAggChain bin = new BinaryUAggChain(left.constructLops(), HopsOpOp2LopsB.get(op), HopsAgg2Lops.get(uRight.getOp()), HopsDirection2Lops.get(uRight.getDirection()), getDataType(), getValueType(), et);
                setOutputDimensions(bin);
                setLineNumbers(bin);
                setLops(bin);
            } else if (mbin == MMBinaryMethod.MR_BINARY_OUTER_R) {
                boolean requiresRepLeft = (!right.dimsKnown() || right.getDim2() > right.getColsInBlock());
                boolean requiresRepRight = (!left.dimsKnown() || left.getDim1() > right.getRowsInBlock());
                Lop leftLop = left.constructLops();
                Lop rightLop = right.constructLops();
                if (requiresRepLeft) {
                    // ncol of right determines rep of left
                    Lop offset = createOffsetLop(right, true);
                    leftLop = new RepMat(leftLop, offset, true, left.getDataType(), left.getValueType());
                    setOutputDimensions(leftLop);
                    setLineNumbers(leftLop);
                }
                if (requiresRepRight) {
                    // nrow of left determines rep of right
                    Lop offset = createOffsetLop(left, false);
                    rightLop = new RepMat(rightLop, offset, false, right.getDataType(), right.getValueType());
                    setOutputDimensions(rightLop);
                    setLineNumbers(rightLop);
                }
                Group group1 = new Group(leftLop, Group.OperationTypes.Sort, getDataType(), getValueType());
                setLineNumbers(group1);
                setOutputDimensions(group1);
                Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType());
                setLineNumbers(group2);
                setOutputDimensions(group2);
                Binary binary = new Binary(group1, group2, HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
                setOutputDimensions(binary);
                setLineNumbers(binary);
                setLops(binary);
            } else { // MMBinaryMethod.MR_BINARY_R
                boolean requiresRep = requiresReplication(left, right);
                Lop rightLop = right.constructLops();
                if (requiresRep) {
                    // ncol of left input (determines num replicates)
                    Lop offset = createOffsetLop(left, (right.getDim2() <= 1));
                    rightLop = new RepMat(rightLop, offset, (right.getDim2() <= 1), right.getDataType(), right.getValueType());
                    setOutputDimensions(rightLop);
                    setLineNumbers(rightLop);
                }
                Group group1 = new Group(getInput().get(0).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                setLineNumbers(group1);
                setOutputDimensions(group1);
                Group group2 = new Group(rightLop, Group.OperationTypes.Sort, getDataType(), getValueType());
                setLineNumbers(group2);
                setOutputDimensions(group2);
                Binary binary = new Binary(group1, group2, HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
                setLineNumbers(binary);
                setOutputDimensions(binary);
                setLops(binary);
            }
        }
    }
}
Also used : Group(org.apache.sysml.lops.Group) BinaryM(org.apache.sysml.lops.BinaryM) BinaryUAggChain(org.apache.sysml.lops.BinaryUAggChain) Lop(org.apache.sysml.lops.Lop) Unary(org.apache.sysml.lops.Unary) CombineUnary(org.apache.sysml.lops.CombineUnary) UnaryCP(org.apache.sysml.lops.UnaryCP) RepMat(org.apache.sysml.lops.RepMat) OperationTypes(org.apache.sysml.lops.CombineBinary.OperationTypes) DataType(org.apache.sysml.parser.Expression.DataType) ExecType(org.apache.sysml.lops.LopProperties.ExecType) Binary(org.apache.sysml.lops.Binary) CombineBinary(org.apache.sysml.lops.CombineBinary) BinaryScalar(org.apache.sysml.lops.BinaryScalar) ConvolutionTransform(org.apache.sysml.lops.ConvolutionTransform) DataPartition(org.apache.sysml.lops.DataPartition)
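
In the matrix-scalar branch above, two common patterns are rewritten to dedicated unary operations before falling back to the generic HopsOpOp2LopsU lookup: X ^ 2 becomes POW2 and X * 2 becomes MULTIPLY2 whenever the right operand is the literal 2. The snippet below sketches just that selection step; the Op and UnaryOpType enums and the isLiteralTwo helper are illustrative placeholders, not SystemML types.

// Hypothetical sketch of the matrix-scalar operator specialization; not SystemML code.
public class UnarySpecializationSketch {

    enum Op { POW, MULT, PLUS }
    enum UnaryOpType { POW2, MULTIPLY2, GENERAL }

    // Stand-in for "the right operand is the literal 2.0"
    static boolean isLiteralTwo(Double rightLiteral) {
        return rightLiteral != null && rightLiteral == 2.0;
    }

    static UnaryOpType select(Op op, Double rightLiteral) {
        if (op == Op.POW && isLiteralTwo(rightLiteral))
            return UnaryOpType.POW2;        // X ^ 2  ->  squaring kernel
        if (op == Op.MULT && isLiteralTwo(rightLiteral))
            return UnaryOpType.MULTIPLY2;   // X * 2  ->  doubling kernel
        return UnaryOpType.GENERAL;         // everything else: generic lookup
    }

    public static void main(String[] args) {
        System.out.println(select(Op.POW, 2.0));   // POW2
        System.out.println(select(Op.MULT, 3.0));  // GENERAL
    }
}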

Example 39 with DataType

use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.

The class TernaryOp, method constructLopsCtable.

/**
 * Method to construct LOPs when op = CTABLE.
 */
private void constructLopsCtable() {
    if (_op != OpOp3.CTABLE)
        throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.CTABLE);
    /*
     * We must handle three different cases:
     *   case 1: all three inputs are vectors  (e.g., F = ctable(A,B,W))
     *   case 2: two vectors and one scalar    (e.g., F = ctable(A,B))
     *   case 3: one vector and two scalars    (e.g., F = ctable(A))
     */
    // identify the particular case
    // F=ctable(A,B,W)
    DataType dt1 = getInput().get(0).getDataType();
    DataType dt2 = getInput().get(1).getDataType();
    DataType dt3 = getInput().get(2).getDataType();
    Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3);
    // Compute lops for all inputs
    Lop[] inputLops = new Lop[getInput().size()];
    for (int i = 0; i < getInput().size(); i++) {
        inputLops[i] = getInput().get(i).constructLops();
    }
    ExecType et = optFindExecType();
    // reset reblock requirement (see MR ctable / construct lops)
    setRequiresReblock(false);
    if (et == ExecType.CP || et == ExecType.SPARK) {
        // for CP we support only ctable expand left
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        boolean ignoreZeros = false;
        if (isMatrixIgnoreZeroRewriteApplicable()) {
            // table - rmempty - rshape
            ignoreZeros = true;
            inputLops[0] = ((ParameterizedBuiltinOp) getInput().get(0)).getTargetHop().getInput().get(0).constructLops();
            inputLops[1] = ((ParameterizedBuiltinOp) getInput().get(1)).getTargetHop().getInput().get(0).constructLops();
        }
        Ctable ternary = new Ctable(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et);
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        setLineNumbers(ternary);
        // force blocked output in CP (see below), otherwise binarycell
        if (et == ExecType.SPARK) {
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, -1, -1, -1);
            setRequiresReblock(true);
        } else
            ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1);
        // ternary opt, w/o reblock in CP
        setLops(ternary);
    } else // MR
    {
        // for MR we support both ctable expand left and right
        Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable() ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig;
        Group group1 = null, group2 = null, group3 = null, group4 = null;
        group1 = new Group(inputLops[0], Group.OperationTypes.Sort, getDataType(), getValueType());
        group1.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(group1);
        Ctable ternary = null;
        // create "group" lops for MATRIX inputs
        switch(ternaryOp) {
            case CTABLE_TRANSFORM:
                // F = ctable(A,B,W)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                group3 = new Group(inputLops[2], Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    // output dimensions are given
                    ternary = new Ctable(new Lop[] { group1, group2, group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_SCALAR_WEIGHT:
                // F = ctable(A,B) or F = ctable(A,B,1)
                group2 = new Group(inputLops[1], Group.OperationTypes.Sort, getDataType(), getValueType());
                group2.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group2);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2] }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, group2, inputLops[2], inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_EXPAND_SCALAR_WEIGHT:
                // F=ctable(seq(1,N),A) or F = ctable(seq,A,1)
                // left 1, right 0 (index of input data)
                int left = isSequenceRewriteApplicable(true) ? 1 : 0;
                Group group = new Group(getInput().get(left).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] {
                        group,                               // matrix
                        getInput().get(2).constructLops(),   // weight
                        new LiteralOp(left).constructLops()  // left
                    }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] {
                        group,                               // matrix
                        getInput().get(2).constructLops(),   // weight
                        new LiteralOp(left).constructLops(), // left
                        inputLops[3], inputLops[4]
                    }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_HISTOGRAM:
                // F=ctable(A,1) or F = ctable(A,1,1)
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops() }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), getInput().get(2).constructLops(), inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM:
                // F=ctable(A,1,W)
                group3 = new Group(getInput().get(2).constructLops(), Group.OperationTypes.Sort, getDataType(), getValueType());
                group3.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(group3);
                if (inputLops.length == 3)
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3 }, ternaryOp, getDataType(), getValueType(), et);
                else
                    ternary = new Ctable(new Lop[] { group1, getInput().get(1).constructLops(), group3, inputLops[3], inputLops[4] }, ternaryOp, getDataType(), getValueType(), et);
                break;
            default:
                throw new HopsException("Invalid ternary operator type: " + _op);
        }
        // output dimensions are not known at compilation time
        ternary.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
        setLineNumbers(ternary);
        Lop lctable = ternary;
        if (!(_disjointInputs || ternaryOp == Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)) {
            // no need for aggregation if (1) the inputs are indexed disjoint or (2) one side is a sequence w/ 1 increment
            group4 = new Group(ternary, Group.OperationTypes.Sort, getDataType(), getValueType());
            group4.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(group4);
            Aggregate agg1 = new Aggregate(group4, HopsAgg2Lops.get(AggOp.SUM), getDataType(), getValueType(), ExecType.MR);
            agg1.getOutputParameters().setDimensions(_dim1, _dim2, (_dimInputsPresent ? getRowsInBlock() : -1), (_dimInputsPresent ? getColsInBlock() : -1), -1);
            setLineNumbers(agg1);
            // Kahan summation is used for aggregation, but the inputs do not have
            // correction values
            agg1.setupCorrectionLocation(CorrectionLocationType.NONE);
            lctable = agg1;
        }
        setLops(lctable);
        // introduce a reblock lop, since table itself outputs in blocked format only if dims are known
        if (!dimsKnown() && !_dimInputsPresent) {
            setRequiresReblock(true);
        }
    }
}
Also used : Group(org.apache.sysml.lops.Group) Lop(org.apache.sysml.lops.Lop) Ctable(org.apache.sysml.lops.Ctable) DataType(org.apache.sysml.parser.Expression.DataType) ExecType(org.apache.sysml.lops.LopProperties.ExecType) Aggregate(org.apache.sysml.lops.Aggregate)
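
The lops constructed above ultimately compute a contingency table: for every position i, the output cell indexed by (A[i], B[i]) is incremented by the weight W[i], or by 1 when no weight vector is given. The plain-Java sketch below illustrates that semantics on dense, 1-based category vectors; it is an illustration of what ctable produces, not the SystemML runtime kernel.

import java.util.Arrays;

// Hypothetical, single-node illustration of ctable(A, B, W) semantics.
public class CtableSketch {

    static double[][] ctable(double[] a, double[] b, double[] w) {
        int rows = (int) Arrays.stream(a).max().orElse(0);
        int cols = (int) Arrays.stream(b).max().orElse(0);
        double[][] f = new double[rows][cols];
        for (int i = 0; i < a.length; i++) {
            // category values are 1-based, as in DML
            f[(int) a[i] - 1][(int) b[i] - 1] += w[i];
        }
        return f;
    }

    public static void main(String[] args) {
        double[] a = { 1, 2, 2, 3 };
        double[] b = { 1, 1, 2, 2 };
        double[] w = { 1, 1, 1, 5 };   // pass all-ones weights to emulate ctable(A,B)
        System.out.println(Arrays.deepToString(ctable(a, b, w)));
        // [[1.0, 0.0], [1.0, 1.0], [0.0, 5.0]]
    }
}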

Example 40 with DataType

use of org.apache.sysml.parser.Expression.DataType in project incubator-systemml by apache.

The class InterProceduralAnalysis, method extractFunctionCallReturnStatistics.

/**
 * Extract return variable statistics from this function into the
 * calling program.
 *
 * @param fstmt  The function statement.
 * @param fop  The function op.
 * @param tmpVars  Function's map of variables eligible for
 *                    extraction.
 * @param callVars  Calling program's map of variables.
 * @param overwrite  Whether or not to overwrite variables in the
 *                      calling program's variable map.
 */
private static void extractFunctionCallReturnStatistics(FunctionStatement fstmt, FunctionOp fop, LocalVariableMap tmpVars, LocalVariableMap callVars, boolean overwrite) {
    ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams();
    String[] outputVars = fop.getOutputVariableNames();
    String fkey = fop.getFunctionKey();
    try {
        for (int i = 0; i < foutputOps.size(); i++) {
            DataIdentifier di = foutputOps.get(i);
            // name in function signature
            String fvarname = di.getName();
            // name in calling program
            String pvarname = outputVars[i];
            // If the function output's data type differs from that of the calling program's
            // output, remove that variable from the calling program's variable map.
            if (callVars.keySet().contains(pvarname)) {
                DataType fdataType = di.getDataType();
                DataType pdataType = callVars.get(pvarname).getDataType();
                if (fdataType != pdataType) {
                    // data type has changed, and the calling program is reassigning
                    // the variable, so remove it from the calling variable map
                    callVars.remove(pvarname);
                }
            }
            // Update or add to the calling program's variable map.
            if (di.getDataType() == DataType.MATRIX && tmpVars.keySet().contains(fvarname)) {
                MatrixObject moIn = (MatrixObject) tmpVars.get(fvarname);
                // not existing so far
                if (!callVars.keySet().contains(pvarname) || overwrite) {
                    MatrixObject moOut = createOutputMatrix(moIn.getNumRows(), moIn.getNumColumns(), moIn.getNnz());
                    callVars.put(pvarname, moOut);
                } else { // already existing: take largest
                    Data dat = callVars.get(pvarname);
                    if (dat instanceof MatrixObject) {
                        MatrixObject moOut = (MatrixObject) dat;
                        MatrixCharacteristics mc = moOut.getMatrixCharacteristics();
                        if (OptimizerUtils.estimateSizeExactSparsity(mc.getRows(), mc.getCols(), (mc.getNonZeros() > 0) ? OptimizerUtils.getSparsity(mc) : 1.0) < OptimizerUtils.estimateSize(moIn.getNumRows(), moIn.getNumColumns())) {
                            // update statistics if necessary
                            mc.setDimension(moIn.getNumRows(), moIn.getNumColumns());
                            mc.setNonZeros(moIn.getNnz());
                        }
                    }
                }
            }
        }
    } catch (Exception ex) {
        throw new HopsException("Failed to extract output statistics of function " + fkey + ".", ex);
    }
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) DataType(org.apache.sysml.parser.Expression.DataType) Data(org.apache.sysml.runtime.instructions.cp.Data) HopsException(org.apache.sysml.hops.HopsException) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics)
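
The size-propagation part of this method can be summarized as: for each matrix output of the function, install the callee's statistics in the caller's variable map, or, when a matrix variable already exists and overwrite is false, keep whichever estimate describes the larger object. The sketch below condenses that rule, with the data-type check omitted and the size comparison simplified to a cell-count comparison; the Stats record and the plain Maps are placeholders standing in for SystemML's MatrixObject and LocalVariableMap.

import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of propagating callee output statistics into the caller's variable map.
public class ReturnStatsSketch {

    record Stats(long rows, long cols, long nnz) {
        long cells() { return rows * cols; }
    }

    static void propagate(Map<String, Stats> calleeVars, Map<String, Stats> callerVars,
                          String fvarname, String pvarname, boolean overwrite) {
        Stats out = calleeVars.get(fvarname);
        if (out == null)
            return; // nothing known about this output
        Stats existing = callerVars.get(pvarname);
        if (existing == null || overwrite || existing.cells() < out.cells()) {
            // install, overwrite, or enlarge the caller-side estimate
            callerVars.put(pvarname, out);
        }
    }

    public static void main(String[] args) {
        Map<String, Stats> callee = new HashMap<>(Map.of("R", new Stats(1000, 50, 20000)));
        Map<String, Stats> caller = new HashMap<>(Map.of("X", new Stats(10, 10, 100)));
        propagate(callee, caller, "R", "X", false);
        System.out.println(caller.get("X")); // Stats[rows=1000, cols=50, nnz=20000]
    }
}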

Aggregations

DataType (org.apache.sysml.parser.Expression.DataType) 59
ValueType (org.apache.sysml.parser.Expression.ValueType) 22
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException) 14
MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics) 10
DataIdentifier (org.apache.sysml.parser.DataIdentifier) 8
CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand) 8
Operator (org.apache.sysml.runtime.matrix.operators.Operator) 8
AggUnaryOp (org.apache.sysml.hops.AggUnaryOp) 6
Lop (org.apache.sysml.lops.Lop) 6
ExecType (org.apache.sysml.lops.LopProperties.ExecType) 6
MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject) 6
Data (org.apache.sysml.runtime.instructions.cp.Data) 6
MetaDataFormat (org.apache.sysml.runtime.matrix.MetaDataFormat) 6
IOException (java.io.IOException) 4
StringTokenizer (java.util.StringTokenizer) 4
HopsException (org.apache.sysml.hops.HopsException) 4
UnaryOp (org.apache.sysml.hops.UnaryOp) 4
Group (org.apache.sysml.lops.Group) 4
PDataPartitionFormat (org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat) 4
PartitionFormat (org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PartitionFormat) 4