Search in sources :

Example 11 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

In class DMLTranslator, method processBuiltinFunctionExpression:

/**
 * Construct Hops from parse tree : Process BuiltinFunction Expression in an
 * assignment statement.
 *
 * <p>The first three arguments of {@code source} (when present) are translated
 * into hops up front; cases that take more arguments translate the remaining
 * ones themselves. The resulting hop is named and typed after {@code target}
 * unless a case states otherwise (e.g. the cast built-ins force the data type
 * or value type).
 *
 * @param source built-in function expression
 * @param target data identifier; when {@code null} a target is created from {@code source}
 * @param hops map of high-level operators
 * @return high-level operator
 */
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    Hop expr = processExpression(source.getFirstExpr(), null, hops);
    Hop expr2 = null;
    if (source.getSecondExpr() != null) {
        expr2 = processExpression(source.getSecondExpr(), null, hops);
    }
    Hop expr3 = null;
    if (source.getThirdExpr() != null) {
        expr3 = processExpression(source.getThirdExpr(), null, hops);
    }
    if (target == null) {
        target = createTarget(source);
    }
    final BuiltinFunctionOp opcode = source.getOpCode();
    Hop currBuiltinOp = null;
    // Construct the hop based on the type of Builtin function
    switch (opcode) {
        case EVAL:
            currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
            break;

        // column-wise aggregates
        case COLSUM:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.SUM, Direction.Col, expr);
            break;
        case COLMAX:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MAX, Direction.Col, expr);
            break;
        case COLMIN:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MIN, Direction.Col, expr);
            break;
        case COLMEAN:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MEAN, Direction.Col, expr);
            break;
        case COLSD:
            // colStdDevs = sqrt(colVariances)
            currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.SQRT, newBuiltinAggUnary(target, AggOp.VAR, Direction.Col, expr));
            break;
        case COLVAR:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.VAR, Direction.Col, expr);
            break;

        // row-wise aggregates
        case ROWSUM:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.SUM, Direction.Row, expr);
            break;
        case ROWMAX:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MAX, Direction.Row, expr);
            break;
        case ROWINDEXMAX:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MAXINDEX, Direction.Row, expr);
            break;
        case ROWINDEXMIN:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MININDEX, Direction.Row, expr);
            break;
        case ROWMIN:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MIN, Direction.Row, expr);
            break;
        case ROWMEAN:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.MEAN, Direction.Row, expr);
            break;
        case ROWSD:
            // rowStdDevs = sqrt(rowVariances)
            currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.SQRT, newBuiltinAggUnary(target, AggOp.VAR, Direction.Row, expr));
            break;
        case ROWVAR:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.VAR, Direction.Row, expr);
            break;

        // size queries: if the dimensions are available at compile time, create a
        // LiteralOp (constant propagation); else create a UnaryOp so that a control
        // program instruction is generated
        case NROW: {
            long nRows = expr.getDim1();
            if (nRows == -1) {
                currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.NROW, expr);
            } else {
                currBuiltinOp = new LiteralOp(nRows);
            }
            break;
        }
        case NCOL: {
            long nCols = expr.getDim2();
            if (nCols == -1) {
                currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.NCOL, expr);
            } else {
                currBuiltinOp = new LiteralOp(nCols);
            }
            break;
        }
        case LENGTH: {
            long nRows = expr.getDim1();
            long nCols = expr.getDim2();
            if ((nCols == -1) || (nRows == -1)) {
                currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.LENGTH, expr);
            } else {
                long lval = (nCols * nRows);
                currBuiltinOp = new LiteralOp(lval);
            }
            break;
        }

        // full aggregates
        case SUM:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.SUM, Direction.RowCol, expr);
            break;
        case MEAN:
            if (expr2 == null) {
                // example: x = mean(Y);
                currBuiltinOp = newBuiltinAggUnary(target, AggOp.MEAN, Direction.RowCol, expr);
            } else {
                // example: x = mean(Y,W);
                // stable weighted mean is implemented by using centralMoment with order = 0
                Hop orderHop = new LiteralOp(0);
                currBuiltinOp = newBuiltinTernary(target, Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop);
            }
            break;
        case SD:
            // stdDev = sqrt(variance); the intermediate variance hop gets scalar output parameters
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.VAR, Direction.RowCol, expr);
            HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
            currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.SQRT, currBuiltinOp);
            break;
        case VAR:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.VAR, Direction.RowCol, expr);
            break;
        case MIN:
            // construct AggUnary for min(X) but BinaryOp for min(X,Y)
            if (expr2 == null) {
                currBuiltinOp = newBuiltinAggUnary(target, AggOp.MIN, Direction.RowCol, expr);
            } else {
                currBuiltinOp = newBuiltinBinary(target, OpOp2.MIN, expr, expr2);
            }
            break;
        case MAX:
            // construct AggUnary for max(X) but BinaryOp for max(X,Y)
            if (expr2 == null) {
                currBuiltinOp = newBuiltinAggUnary(target, AggOp.MAX, Direction.RowCol, expr);
            } else {
                currBuiltinOp = newBuiltinBinary(target, OpOp2.MAX, expr, expr2);
            }
            break;
        case PPRED:
            currBuiltinOp = newBuiltinBinary(target, getPPredOperation(source), expr, expr2);
            break;
        case PROD:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.PROD, Direction.RowCol, expr);
            break;
        case TRACE:
            currBuiltinOp = newBuiltinAggUnary(target, AggOp.TRACE, Direction.RowCol, expr);
            break;

        // reorganizations and appends
        case TRANS:
            currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.TRANSPOSE, expr);
            break;
        case REV:
            currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.REV, expr);
            break;
        case CBIND:
        case RBIND:
            // exactly two inputs -> binary append, otherwise n-ary append over all arguments
            if (source.getAllExpr().length == 2) {
                OpOp2 binaryAppend = (opcode == BuiltinFunctionOp.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
                currBuiltinOp = newBuiltinBinary(target, binaryAppend, expr, expr2);
            } else {
                OpOpN naryAppend = (opcode == BuiltinFunctionOp.CBIND) ? OpOpN.CBIND : OpOpN.RBIND;
                currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), naryAppend, processAllExpressions(source.getAllExpr(), hops));
            }
            break;
        case DIAG:
            currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.DIAG, expr);
            break;
        case TABLE:
            currBuiltinOp = newBuiltinTable(source, target, hops, expr, expr2, expr3);
            break;

        // data type casts (output data type is forced, value type comes from target)
        case CAST_AS_SCALAR:
            currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
            break;
        case CAST_AS_MATRIX:
            currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
            break;
        case CAST_AS_FRAME:
            currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
            break;

        // value type casts (output value type is forced, data type comes from target)
        case CAST_AS_DOUBLE:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.DOUBLE, Hop.OpOp1.CAST_AS_DOUBLE, expr);
            break;
        case CAST_AS_INT:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT, Hop.OpOp1.CAST_AS_INT, expr);
            break;
        case CAST_AS_BOOLEAN:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
            break;

        // Boolean / bitwise binary
        case XOR:
            currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.XOR, expr, expr2);
            break;
        case BITWAND:
            currBuiltinOp = newBuiltinBinary(target, OpOp2.BITWAND, expr, expr2);
            break;
        case BITWOR:
            currBuiltinOp = newBuiltinBinary(target, OpOp2.BITWOR, expr, expr2);
            break;
        case BITWXOR:
            currBuiltinOp = newBuiltinBinary(target, OpOp2.BITWXOR, expr, expr2);
            break;
        case BITWSHIFTL:
            currBuiltinOp = newBuiltinBinary(target, OpOp2.BITWSHIFTL, expr, expr2);
            break;
        case BITWSHIFTR:
            currBuiltinOp = newBuiltinBinary(target, OpOp2.BITWSHIFTR, expr, expr2);
            break;

        // element-wise / cumulative unary math
        case ABS:
        case SIN:
        case COS:
        case TAN:
        case ASIN:
        case ACOS:
        case ATAN:
        case SINH:
        case COSH:
        case TANH:
        case SIGN:
        case SQRT:
        case EXP:
        case ROUND:
        case CEIL:
        case FLOOR:
        case CUMSUM:
        case CUMPROD:
        case CUMMIN:
        case CUMMAX:
            currBuiltinOp = newBuiltinUnary(target, getUnaryMathOp(source), expr);
            break;
        case LOG:
            // log(X) is unary, log(X, base) is binary
            if (expr2 == null) {
                currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.LOG, expr);
            } else {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.LOG, expr, expr2);
            }
            break;

        // statistics: binary form without third argument, ternary form with it
        case MOMENT:
            if (expr3 == null) {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.CENTRALMOMENT, expr, expr2);
            } else {
                currBuiltinOp = newBuiltinTernary(target, Hop.OpOp3.CENTRALMOMENT, expr, expr2, expr3);
            }
            break;
        case COV:
            if (expr3 == null) {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.COVARIANCE, expr, expr2);
            } else {
                currBuiltinOp = newBuiltinTernary(target, Hop.OpOp3.COVARIANCE, expr, expr2, expr3);
            }
            break;
        case QUANTILE:
            if (expr3 == null) {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.QUANTILE, expr, expr2);
            } else {
                currBuiltinOp = newBuiltinTernary(target, Hop.OpOp3.QUANTILE, expr, expr2, expr3);
            }
            break;
        case INTERQUANTILE:
            if (expr3 == null) {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.INTERQUANTILE, expr, expr2);
            } else {
                currBuiltinOp = newBuiltinTernary(target, Hop.OpOp3.INTERQUANTILE, expr, expr2, expr3);
            }
            break;
        case IQM:
            if (expr2 == null) {
                currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.IQM, expr);
            } else {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.IQM, expr, expr2);
            }
            break;
        case MEDIAN:
            if (expr2 == null) {
                currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.MEDIAN, expr);
            } else {
                currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.MEDIAN, expr, expr2);
            }
            break;
        case IFELSE:
            currBuiltinOp = newBuiltinTernary(target, Hop.OpOp3.IFELSE, expr, expr2, expr3);
            break;

        // data generators
        case SEQ: {
            HashMap<String, Hop> seqParams = new HashMap<>();
            seqParams.put(Statement.SEQ_FROM, expr);
            seqParams.put(Statement.SEQ_TO, expr2);
            // note incr: default -1 (for from>to) handled during runtime
            seqParams.put(Statement.SEQ_INCR, (expr3 != null) ? expr3 : new LiteralOp(1));
            currBuiltinOp = new DataGenOp(DataGenMethod.SEQ, target, seqParams);
            break;
        }
        case SAMPLE: {
            Expression[] in = source.getAllExpr();
            // arguments: range/size/replace/seed; defaults: replace=FALSE
            HashMap<String, Hop> sampleParams = new HashMap<>();
            // range
            sampleParams.put(DataExpression.RAND_MAX, expr);
            sampleParams.put(DataExpression.RAND_ROWS, expr2);
            sampleParams.put(DataExpression.RAND_COLS, new LiteralOp(1));
            // NOTE(review): for argument counts other than 2..4 neither RAND_PDF nor
            // RAND_SEED is set; presumably the count is validated earlier - confirm.
            if (in.length == 4) {
                sampleParams.put(DataExpression.RAND_PDF, expr3);
                Hop seed = processExpression(in[3], null, hops);
                sampleParams.put(DataExpression.RAND_SEED, seed);
            } else if (in.length == 3) {
                // check if the third argument is "replace" or "seed"
                if (expr3.getValueType() == ValueType.BOOLEAN) {
                    sampleParams.put(DataExpression.RAND_PDF, expr3);
                    sampleParams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
                } else if (expr3.getValueType() == ValueType.INT) {
                    sampleParams.put(DataExpression.RAND_PDF, new LiteralOp(false));
                    sampleParams.put(DataExpression.RAND_SEED, expr3);
                } else {
                    throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");
                }
            } else if (in.length == 2) {
                sampleParams.put(DataExpression.RAND_PDF, new LiteralOp(false));
                sampleParams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
            }
            currBuiltinOp = new DataGenOp(DataGenMethod.SAMPLE, target, sampleParams);
            break;
        }

        // linear algebra
        case SOLVE:
            currBuiltinOp = newBuiltinBinary(target, Hop.OpOp2.SOLVE, expr, expr2);
            break;
        case INVERSE:
            currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.INVERSE, expr);
            break;
        case CHOLESKY:
            currBuiltinOp = newBuiltinUnary(target, Hop.OpOp1.CHOLESKY, expr);
            break;
        case OUTER: {
            if (!(expr3 instanceof LiteralOp)) {
                throw new HopsException("Operator for outer builtin function must be a constant: " + expr3);
            }
            OpOp2 outerOp = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp) expr3).getStringValue());
            if (outerOp == null) {
                throw new HopsException("Unsupported outer vector binary operation: " + ((LiteralOp) expr3).getStringValue());
            }
            currBuiltinOp = newBuiltinBinary(target, outerOp, expr, expr2);
            // flag op as specific outer vector operation
            ((BinaryOp) currBuiltinOp).setOuterVectorOperation(true);
            // force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
            currBuiltinOp.refreshSizeInformation();
            break;
        }

        // convolution-style operators; the first argument drives block size / size refresh
        case CONV2D:
            currBuiltinOp = newBuiltinConvolution(target, Hop.ConvOp.DIRECT_CONV2D, getALHopsForConvOp(expr, source, 1, hops), expr);
            break;
        case BIAS_ADD:
        case BIAS_MULTIPLY: {
            ArrayList<Hop> biasInputs = new ArrayList<>();
            biasInputs.add(expr);
            biasInputs.add(expr2);
            Hop.ConvOp biasOp = (opcode == BuiltinFunctionOp.BIAS_ADD) ? Hop.ConvOp.BIAS_ADD : Hop.ConvOp.BIAS_MULTIPLY;
            currBuiltinOp = newBuiltinConvolution(target, biasOp, biasInputs, expr);
            break;
        }
        case AVG_POOL:
        case MAX_POOL: {
            ArrayList<Hop> poolInputs = getALHopsForPoolingForwardIM2COL(expr, source, 1, hops);
            Hop.ConvOp poolOp = (opcode == BuiltinFunctionOp.MAX_POOL) ? Hop.ConvOp.MAX_POOLING : Hop.ConvOp.AVG_POOLING;
            currBuiltinOp = newBuiltinConvolution(target, poolOp, poolInputs, expr);
            break;
        }
        case AVG_POOL_BACKWARD:
        case MAX_POOL_BACKWARD: {
            // process dout as well
            ArrayList<Hop> poolInputs = getALHopsForConvOpPoolingCOL2IM(expr, source, 1, hops);
            Hop.ConvOp poolOp = (opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD) ? Hop.ConvOp.MAX_POOLING_BACKWARD : Hop.ConvOp.AVG_POOLING_BACKWARD;
            currBuiltinOp = newBuiltinConvolution(target, poolOp, poolInputs, expr);
            break;
        }
        case CONV2D_BACKWARD_FILTER:
            currBuiltinOp = newBuiltinConvolution(target, Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, getALHopsForConvOp(expr, source, 1, hops), expr);
            break;
        case CONV2D_BACKWARD_DATA:
            currBuiltinOp = newBuiltinConvolution(target, Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA, getALHopsForConvOp(expr, source, 1, hops), expr);
            break;
        default:
            throw new ParseException("Unsupported builtin function type: " + opcode);
    }
    // conv2d and pooling hops (but not bias_add / bias_multiply) keep the size
    // information set above, since the dimension of the output does not match
    // that of the input variable for these operations
    boolean isConvolution = opcode == BuiltinFunctionOp.CONV2D
        || opcode == BuiltinFunctionOp.CONV2D_BACKWARD_DATA
        || opcode == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER
        || opcode == BuiltinFunctionOp.MAX_POOL
        || opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD
        || opcode == BuiltinFunctionOp.AVG_POOL
        || opcode == BuiltinFunctionOp.AVG_POOL_BACKWARD;
    if (!isConvolution) {
        setIdentifierParams(currBuiltinOp, source.getOutput());
    }
    currBuiltinOp.setParseInfo(source);
    return currBuiltinOp;
}

/** Creates an AggUnaryOp named and typed after {@code target}. */
private static Hop newBuiltinAggUnary(DataIdentifier target, AggOp aggOp, Direction direction, Hop input) {
    return new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), aggOp, direction, input);
}

/** Creates a UnaryOp named and typed after {@code target}. */
private static Hop newBuiltinUnary(DataIdentifier target, Hop.OpOp1 op, Hop input) {
    return new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, input);
}

/** Creates a BinaryOp named and typed after {@code target}. */
private static Hop newBuiltinBinary(DataIdentifier target, OpOp2 op, Hop left, Hop right) {
    return new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right);
}

/** Creates a three-input TernaryOp named and typed after {@code target}. */
private static Hop newBuiltinTernary(DataIdentifier target, Hop.OpOp3 op, Hop in1, Hop in2, Hop in3) {
    return new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), op, in1, in2, in3);
}

/** Creates a ConvolutionOp over {@code inputs} and sets its block size / size info from {@code sizeSource}. */
private Hop newBuiltinConvolution(DataIdentifier target, Hop.ConvOp convOp, ArrayList<Hop> inputs, Hop sizeSource) {
    Hop convHop = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), convOp, inputs);
    setBlockSizeAndRefreshSizeInfo(sizeSource, convHop);
    return convHop;
}

/**
 * Maps the comparison string given as third argument of ppred() to its binary operator.
 *
 * @throws ParseException if the string is not one of {@code >=, >, <=, <, ==, !=}
 */
private OpOp2 getPPredOperation(BuiltinFunctionExpression source) {
    String sop = ((StringIdentifier) source.getThirdExpr()).getValue();
    sop = sop.replace("\"", "");
    if (sop.equalsIgnoreCase(">=")) {
        return OpOp2.GREATEREQUAL;
    }
    if (sop.equalsIgnoreCase(">")) {
        return OpOp2.GREATER;
    }
    if (sop.equalsIgnoreCase("<=")) {
        return OpOp2.LESSEQUAL;
    }
    if (sop.equalsIgnoreCase("<")) {
        return OpOp2.LESS;
    }
    if (sop.equalsIgnoreCase("==")) {
        return OpOp2.EQUAL;
    }
    if (sop.equalsIgnoreCase("!=")) {
        return OpOp2.NOTEQUAL;
    }
    String message = source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.";
    LOG.error(message);
    throw new ParseException(message);
}

/**
 * Maps the opcode of an element-wise or cumulative unary math built-in to its unary hop operator.
 *
 * @throws ParseException if the opcode of {@code source} is not one of the handled math built-ins
 */
private Hop.OpOp1 getUnaryMathOp(BuiltinFunctionExpression source) {
    switch (source.getOpCode()) {
        case ABS:
            return Hop.OpOp1.ABS;
        case SIN:
            return Hop.OpOp1.SIN;
        case COS:
            return Hop.OpOp1.COS;
        case TAN:
            return Hop.OpOp1.TAN;
        case ASIN:
            return Hop.OpOp1.ASIN;
        case ACOS:
            return Hop.OpOp1.ACOS;
        case ATAN:
            return Hop.OpOp1.ATAN;
        case SINH:
            return Hop.OpOp1.SINH;
        case COSH:
            return Hop.OpOp1.COSH;
        case TANH:
            return Hop.OpOp1.TANH;
        case SIGN:
            return Hop.OpOp1.SIGN;
        case SQRT:
            return Hop.OpOp1.SQRT;
        case EXP:
            return Hop.OpOp1.EXP;
        case ROUND:
            return Hop.OpOp1.ROUND;
        case CEIL:
            return Hop.OpOp1.CEIL;
        case FLOOR:
            return Hop.OpOp1.FLOOR;
        case CUMSUM:
            return Hop.OpOp1.CUMSUM;
        case CUMPROD:
            return Hop.OpOp1.CUMPROD;
        case CUMMIN:
            return Hop.OpOp1.CUMMIN;
        case CUMMAX:
            return Hop.OpOp1.CUMMAX;
        default:
            String message = source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode();
            LOG.error(message);
            throw new ParseException(message);
    }
}

/**
 * Builds the hop for table(): always a TernaryOp with operator CTABLE.
 *
 * <p>Supported forms: table(A,B), table(A,B,W), table(A,B,dim1,dim2) and
 * table(A,B,W,dim1,dim2). When no weights are given a literal 1.0 is used.
 *
 * @throws ParseException if the number of arguments is not between 2 and 5
 */
private Hop newBuiltinTable(BuiltinFunctionExpression source, DataIdentifier target, HashMap<String, Hop> hops, Hop expr, Hop expr2, Hop expr3) {
    int numTableArgs = source._args.length;
    if (numTableArgs < 2 || numTableArgs > 5) {
        throw new ParseException("Invalid number of arguments " + numTableArgs + " to table() function.");
    }
    Hop weightHop;
    // index of the first output-dimension argument (only used for the 4/5 argument forms)
    int firstDimArg;
    if (numTableArgs == 3 || numTableArgs == 5) {
        // example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15)
        weightHop = expr3;
        firstDimArg = 3;
    } else {
        // example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
        // here, weight is interpreted as 1.0
        weightHop = new LiteralOp(1.0);
        // set dimensions
        weightHop.setDim1(0);
        weightHop.setDim2(0);
        weightHop.setNnz(-1);
        weightHop.setRowsInBlock(0);
        weightHop.setColsInBlock(0);
        firstDimArg = 2;
    }
    if (numTableArgs <= 3) {
        return new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
    }
    Hop outDim1 = processExpression(source._args[firstDimArg], null, hops);
    Hop outDim2 = processExpression(source._args[firstDimArg + 1], null, hops);
    return new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2);
}
Also used : OpOpN(org.apache.sysml.hops.Hop.OpOpN) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) OpOp2(org.apache.sysml.hops.Hop.OpOp2) DataGenOp(org.apache.sysml.hops.DataGenOp) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) HopsException(org.apache.sysml.hops.HopsException) TernaryOp(org.apache.sysml.hops.TernaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) ConvolutionOp(org.apache.sysml.hops.ConvolutionOp) NaryOp(org.apache.sysml.hops.NaryOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2)

Example 12 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

In class RewriteSplitDagDataDependentOperators, method rCollectDataDependentOperators:

/**
 * Recursively walks the hop DAG below {@code hop} and appends to {@code cand}
 * every hop matched by one of four rules: removeEmpty, ctable without
 * explicit output dims, non-literal/non-data parameters of a sort, and eval.
 * Rules 1, 2 and 4 only fire when a split is actually required, i.e. the
 * hop's dims are unknown and it has consumers other than writes.
 *
 * Every hop reached is marked visited; inputs of a collected hop are not
 * traversed by this call.
 *
 * @param hop  current hop to inspect
 * @param cand output list that collected candidate hops are appended to
 */
private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
    if (hop.isVisited()) {
        return;
    }

    // a split is only useful if dims are unknown and there are non-write consumers
    final boolean splitRequired = !(hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true));
    boolean descend = true;

    // rule 1: removeEmpty
    if (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.RMEMPTY && splitRequired) {
        // exception: the single consumer is a ternary op for which the
        // matrix-ignore-zero rewrite is applicable
        boolean coveredByTernaryRewrite = false;
        if (hop.getParent().size() == 1) {
            Hop soleParent = hop.getParent().get(0);
            coveredByTernaryRewrite = soleParent instanceof TernaryOp && ((TernaryOp) soleParent).isMatrixIgnoreZeroRewriteApplicable();
        }
        if (!coveredByTernaryRewrite) {
            ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hop;
            cand.add(rmEmpty);
            descend = false;

            // record consumer properties and flag the hops accordingly
            boolean emptyBlocksNeeded = false;
            boolean allParentsLeftMatMult = true;
            boolean diagTarget = rmEmpty.isTargetDiagInput();
            for (Hop consumer : hop.getParent()) {
                boolean leftOfMatMult = consumer instanceof AggBinaryOp && hop == consumer.getInput().get(0);
                if (!leftOfMatMult) {
                    allParentsLeftMatMult = false;
                    // besides left-hand matrix multiplies, only nrow() can do without
                    // empty blocks (list to be extended as needed)
                    if (!HopRewriteUtils.isUnary(consumer, OpOp1.NROW)) {
                        emptyBlocksNeeded = true;
                    }
                }
            }
            rmEmpty.setOutputEmptyBlocks(emptyBlocksNeeded);

            if (allParentsLeftMatMult && diagTarget) {
                if (ConfigurationManager.isDynamicRecompilation()) {
                    rmEmpty.setOutputPermutationMatrix(true);
                }
                // safe cast: every parent was verified to be an AggBinaryOp above
                for (Hop consumer : hop.getParent()) {
                    ((AggBinaryOp) consumer).setHasLeftPMInput(true);
                }
            }
        }
    }

    // rule 2: ctable whose output dims were not provided (fewer than 4 inputs)
    if (HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) && hop.getInput().size() < 4 && splitRequired) {
        cand.add(hop);
        descend = false;

        // record consumer properties and flag the hop accordingly
        boolean allParentsLeftMatMult = true;
        for (Hop consumer : hop.getParent()) {
            if (!(consumer instanceof AggBinaryOp && hop == consumer.getInput().get(0))) {
                allParentsLeftMatMult = false;
            }
        }
        if (allParentsLeftMatMult && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0))) {
            hop.setOutputEmptyBlocks(false);
        }
    }

    // rule 3: sort parameters ('decreasing' / 'indexreturn') computed in the same DAG
    if (HopRewriteUtils.isReorg(hop, ReOrgOp.SORT)) {
        for (int pos = 2; pos < 4; pos++) {
            Hop param = hop.getInput().get(pos);
            if (param instanceof LiteralOp || param instanceof DataOp) {
                continue;
            }
            cand.add(param);
            param.setVisited();
            descend = false;
        }
    }

    // rule 4: second-order eval function
    if (HopRewriteUtils.isNary(hop, OpOpN.EVAL) && splitRequired) {
        cand.add(hop);
        descend = false;
    }

    // recurse into the inputs only if nothing was collected at this hop; otherwise
    // the inputs are left to a later, recursive application of the rule
    if (descend && hop.getInput() != null) {
        for (Hop input : hop.getInput()) {
            rCollectDataDependentOperators(input, cand);
        }
    }
    hop.setVisited();
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 13 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyTableSeqExpand.

/**
 * Replaces a five-input ternary operator whose third input is the literal 1
 * and where one of the first two inputs is a 1..N sequence matching the
 * other input, by an equivalent REXPAND parameterized builtin. Both the
 * left-sequence and right-sequence forms are supported.
 *
 * @param parent parent hop whose child reference is rewired on a match
 * @param hi     current hop, candidate for the rewrite
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new REXPAND hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyTableSeqExpand(Hop parent, Hop hi, int pos) {
    // only five-input ternary ops with a constant weight of 1 (input 2) qualify
    if (!(hi instanceof TernaryOp) || hi.getInput().size() != 5 || !HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1)) {
        return hi;
    }

    Hop first = hi.getInput().get(0);
    Hop second = hi.getInput().get(1);

    Hop expandTarget;
    Hop expandMax;
    String expandDir;
    String appliedMsg;
    if (HopRewriteUtils.isBasic1NSequence(first, second, true) && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), second, true)) {
        // pattern a: inputs are (seq(1,nrow(v)), v, 1, nrow(v), m)
        expandTarget = second;
        expandMax = hi.getInput().get(4);
        expandDir = "cols";
        appliedMsg = "Applied simplifyTableSeqExpand1 (line ";
    } else if (HopRewriteUtils.isBasic1NSequence(second, first, true) && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(4), first, true)) {
        // pattern b: inputs are (v, seq(1,nrow(v)), 1, m, nrow(v))
        expandTarget = first;
        expandMax = hi.getInput().get(3);
        expandDir = "rows";
        appliedMsg = "Applied simplifyTableSeqExpand2 (line ";
    } else {
        return hi;
    }

    // assemble the named arguments of the replacement operator
    HashMap<String, Hop> args = new HashMap<>();
    args.put("target", expandTarget);
    args.put("max", expandMax);
    args.put("dir", new LiteralOp(expandDir));
    args.put("ignore", new LiteralOp(false));
    args.put("cast", new LiteralOp(true));

    // create the new hop, rewire the parent, and drop the now-unreferenced original
    ParameterizedBuiltinOp rexpand = HopRewriteUtils.createParameterizedBuiltinOp(expandTarget, args, ParamBuiltinOp.REXPAND);
    HopRewriteUtils.replaceChildReference(parent, hi, rexpand, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);

    // the reported line number is that of the replacement hop
    LOG.debug(appliedMsg + rexpand.getBeginLine() + ")");
    return rexpand;
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 14 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project systemml by apache.

the class DMLTranslator method processBuiltinFunctionExpression.

/**
 * Construct Hops from parse tree : Process BuiltinFunction Expression in an
 * assignment statement
 *
 * @param source built-in function expression
 * @param target data identifier
 * @param hops map of high-level operators
 * @return high-level operator
 */
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    Hop expr = processExpression(source.getFirstExpr(), null, hops);
    Hop expr2 = null;
    if (source.getSecondExpr() != null) {
        expr2 = processExpression(source.getSecondExpr(), null, hops);
    }
    Hop expr3 = null;
    if (source.getThirdExpr() != null) {
        expr3 = processExpression(source.getThirdExpr(), null, hops);
    }
    Hop currBuiltinOp = null;
    if (target == null) {
        target = createTarget(source);
    }
    // Construct the hop based on the type of Builtin function
    switch(source.getOpCode()) {
        case EVAL:
            currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
            break;
        case COLSUM:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.Col, expr);
            break;
        case COLMAX:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.Col, expr);
            break;
        case COLMIN:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.Col, expr);
            break;
        case COLMEAN:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Col, expr);
            break;
        case COLSD:
            // colStdDevs = sqrt(colVariances)
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Col, expr);
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
            break;
        case COLVAR:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Col, expr);
            break;
        case ROWSUM:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.Row, expr);
            break;
        case ROWMAX:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.Row, expr);
            break;
        case ROWINDEXMAX:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAXINDEX, Direction.Row, expr);
            break;
        case ROWINDEXMIN:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MININDEX, Direction.Row, expr);
            break;
        case ROWMIN:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.Row, expr);
            break;
        case ROWMEAN:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Row, expr);
            break;
        case ROWSD:
            // rowStdDevs = sqrt(rowVariances)
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Row, expr);
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
            break;
        case ROWVAR:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Row, expr);
            break;
        case NROW:
            // If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
            // Else create a UnaryOp so that a control program instruction is generated
            long nRows = expr.getDim1();
            if (nRows == -1) {
                currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NROW, expr);
            } else {
                currBuiltinOp = new LiteralOp(nRows);
            }
            break;
        case NCOL:
            // If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
            // Else create a UnaryOp so that a control program instruction is generated
            long nCols = expr.getDim2();
            if (nCols == -1) {
                currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NCOL, expr);
            } else {
                currBuiltinOp = new LiteralOp(nCols);
            }
            break;
        case LENGTH:
            long nRows2 = expr.getDim1();
            long nCols2 = expr.getDim2();
            /* 
			 * If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
			 * Else create a UnaryOp so that a control program instruction is generated
			 */
            if ((nCols2 == -1) || (nRows2 == -1)) {
                currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.LENGTH, expr);
            } else {
                long lval = (nCols2 * nRows2);
                currBuiltinOp = new LiteralOp(lval);
            }
            break;
        case EXISTS:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.EXISTS, expr);
            break;
        case SUM:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.RowCol, expr);
            break;
        case MEAN:
            if (expr2 == null) {
                // example: x = mean(Y);
                currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.RowCol, expr);
            } else {
                // example: x = mean(Y,W);
                // stable weighted mean is implemented by using centralMoment with order = 0
                Hop orderHop = new LiteralOp(0);
                currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop);
            }
            break;
        case SD:
            // stdDev = sqrt(variance)
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
            HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
            break;
        case VAR:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
            break;
        case MIN:
            // construct AggUnary for min(X) but BinaryOp for min(X,Y)
            if (expr2 == null) {
                currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.RowCol, expr);
            } else {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MIN, expr, expr2);
            }
            break;
        case MAX:
            // construct AggUnary for max(X) but BinaryOp for max(X,Y)
            if (expr2 == null) {
                currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.RowCol, expr);
            } else {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MAX, expr, expr2);
            }
            break;
        case PPRED:
            String sop = ((StringIdentifier) source.getThirdExpr()).getValue();
            sop = sop.replace("\"", "");
            OpOp2 operation;
            if (sop.equalsIgnoreCase(">="))
                operation = OpOp2.GREATEREQUAL;
            else if (sop.equalsIgnoreCase(">"))
                operation = OpOp2.GREATER;
            else if (sop.equalsIgnoreCase("<="))
                operation = OpOp2.LESSEQUAL;
            else if (sop.equalsIgnoreCase("<"))
                operation = OpOp2.LESS;
            else if (sop.equalsIgnoreCase("=="))
                operation = OpOp2.EQUAL;
            else if (sop.equalsIgnoreCase("!="))
                operation = OpOp2.NOTEQUAL;
            else {
                LOG.error(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
                throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
            }
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2);
            break;
        case PROD:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD, Direction.RowCol, expr);
            break;
        case TRACE:
            currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.TRACE, Direction.RowCol, expr);
            break;
        case TRANS:
            currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.TRANSPOSE, expr);
            break;
        case REV:
            currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.REV, expr);
            break;
        case CBIND:
        case RBIND:
            OpOp2 appendOp1 = (source.getOpCode() == BuiltinFunctionOp.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
            OpOpN appendOp2 = (source.getOpCode() == BuiltinFunctionOp.CBIND) ? OpOpN.CBIND : OpOpN.RBIND;
            currBuiltinOp = (source.getAllExpr().length == 2) ? new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp1, expr, expr2) : new NaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp2, processAllExpressions(source.getAllExpr(), hops));
            break;
        case DIAG:
            currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.DIAG, expr);
            break;
        case TABLE:
            // Always a TertiaryOp is created for table().
            // - create a hop for weights, if not provided in the function call.
            int numTableArgs = source._args.length;
            switch(numTableArgs) {
                case 2:
                case 4:
                    // example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
                    // here, weight is interpreted as 1.0
                    Hop weightHop = new LiteralOp(1.0);
                    // set dimensions
                    weightHop.setDim1(0);
                    weightHop.setDim2(0);
                    weightHop.setNnz(-1);
                    weightHop.setRowsInBlock(0);
                    weightHop.setColsInBlock(0);
                    if (numTableArgs == 2)
                        currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
                    else {
                        Hop outDim1 = processExpression(source._args[2], null, hops);
                        Hop outDim2 = processExpression(source._args[3], null, hops);
                        currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2);
                    }
                    break;
                case 3:
                case 5:
                    // example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15)
                    if (numTableArgs == 3)
                        currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3);
                    else {
                        Hop outDim1 = processExpression(source._args[3], null, hops);
                        Hop outDim2 = processExpression(source._args[4], null, hops);
                        currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2);
                    }
                    break;
                default:
                    throw new ParseException("Invalid number of arguments " + numTableArgs + " to table() function.");
            }
            break;
        // data type casts
        case CAST_AS_SCALAR:
            currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
            break;
        case CAST_AS_MATRIX:
            currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
            break;
        case CAST_AS_FRAME:
            currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
            break;
        // value type casts
        case CAST_AS_DOUBLE:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.DOUBLE, Hop.OpOp1.CAST_AS_DOUBLE, expr);
            break;
        case CAST_AS_INT:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT, Hop.OpOp1.CAST_AS_INT, expr);
            break;
        case CAST_AS_BOOLEAN:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
            break;
        // Boolean binary
        case XOR:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.XOR, expr, expr2);
            break;
        case BITWAND:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWAND, expr, expr2);
            break;
        case BITWOR:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWOR, expr, expr2);
            break;
        case BITWXOR:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWXOR, expr, expr2);
            break;
        case BITWSHIFTL:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWSHIFTL, expr, expr2);
            break;
        case BITWSHIFTR:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWSHIFTR, expr, expr2);
            break;
        case ABS:
        case SIN:
        case COS:
        case TAN:
        case ASIN:
        case ACOS:
        case ATAN:
        case SINH:
        case COSH:
        case TANH:
        case SIGN:
        case SQRT:
        case EXP:
        case ROUND:
        case CEIL:
        case FLOOR:
        case CUMSUM:
        case CUMPROD:
        case CUMMIN:
        case CUMMAX:
            Hop.OpOp1 mathOp1;
            switch(source.getOpCode()) {
                case ABS:
                    mathOp1 = Hop.OpOp1.ABS;
                    break;
                case SIN:
                    mathOp1 = Hop.OpOp1.SIN;
                    break;
                case COS:
                    mathOp1 = Hop.OpOp1.COS;
                    break;
                case TAN:
                    mathOp1 = Hop.OpOp1.TAN;
                    break;
                case ASIN:
                    mathOp1 = Hop.OpOp1.ASIN;
                    break;
                case ACOS:
                    mathOp1 = Hop.OpOp1.ACOS;
                    break;
                case ATAN:
                    mathOp1 = Hop.OpOp1.ATAN;
                    break;
                case SINH:
                    mathOp1 = Hop.OpOp1.SINH;
                    break;
                case COSH:
                    mathOp1 = Hop.OpOp1.COSH;
                    break;
                case TANH:
                    mathOp1 = Hop.OpOp1.TANH;
                    break;
                case SIGN:
                    mathOp1 = Hop.OpOp1.SIGN;
                    break;
                case SQRT:
                    mathOp1 = Hop.OpOp1.SQRT;
                    break;
                case EXP:
                    mathOp1 = Hop.OpOp1.EXP;
                    break;
                case ROUND:
                    mathOp1 = Hop.OpOp1.ROUND;
                    break;
                case CEIL:
                    mathOp1 = Hop.OpOp1.CEIL;
                    break;
                case FLOOR:
                    mathOp1 = Hop.OpOp1.FLOOR;
                    break;
                case CUMSUM:
                    mathOp1 = Hop.OpOp1.CUMSUM;
                    break;
                case CUMPROD:
                    mathOp1 = Hop.OpOp1.CUMPROD;
                    break;
                case CUMMIN:
                    mathOp1 = Hop.OpOp1.CUMMIN;
                    break;
                case CUMMAX:
                    mathOp1 = Hop.OpOp1.CUMMAX;
                    break;
                default:
                    LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
                    throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
            }
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp1, expr);
            break;
        case LOG:
            if (expr2 == null) {
                Hop.OpOp1 mathOp2;
                switch(source.getOpCode()) {
                    case LOG:
                        mathOp2 = Hop.OpOp1.LOG;
                        break;
                    default:
                        LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
                        throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
                }
                currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp2, expr);
            } else {
                Hop.OpOp2 mathOp3;
                switch(source.getOpCode()) {
                    case LOG:
                        mathOp3 = Hop.OpOp2.LOG;
                        break;
                    default:
                        LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
                        throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
                }
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp3, expr, expr2);
            }
            break;
        case MOMENT:
            if (expr3 == null) {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.CENTRALMOMENT, expr, expr2);
            } else {
                currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.CENTRALMOMENT, expr, expr2, expr3);
            }
            break;
        case COV:
            if (expr3 == null) {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.COVARIANCE, expr, expr2);
            } else {
                currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.COVARIANCE, expr, expr2, expr3);
            }
            break;
        case QUANTILE:
            if (expr3 == null) {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.QUANTILE, expr, expr2);
            } else {
                currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.QUANTILE, expr, expr2, expr3);
            }
            break;
        case INTERQUANTILE:
            if (expr3 == null) {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.INTERQUANTILE, expr, expr2);
            } else {
                currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.INTERQUANTILE, expr, expr2, expr3);
            }
            break;
        case IQM:
            if (expr2 == null) {
                currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.IQM, expr);
            } else {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.IQM, expr, expr2);
            }
            break;
        case MEDIAN:
            if (expr2 == null) {
                currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.MEDIAN, expr);
            } else {
                currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.MEDIAN, expr, expr2);
            }
            break;
        case IFELSE:
            currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.IFELSE, expr, expr2, expr3);
            break;
        case SEQ:
            HashMap<String, Hop> randParams = new HashMap<>();
            randParams.put(Statement.SEQ_FROM, expr);
            randParams.put(Statement.SEQ_TO, expr2);
            randParams.put(Statement.SEQ_INCR, (expr3 != null) ? expr3 : new LiteralOp(1));
            // note incr: default -1 (for from>to) handled during runtime
            currBuiltinOp = new DataGenOp(DataGenMethod.SEQ, target, randParams);
            break;
        case SAMPLE:
            {
                Expression[] in = source.getAllExpr();
                // arguments: range/size/replace/seed; defaults: replace=FALSE
                HashMap<String, Hop> tmpparams = new HashMap<>();
                // range
                tmpparams.put(DataExpression.RAND_MAX, expr);
                tmpparams.put(DataExpression.RAND_ROWS, expr2);
                tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1));
                if (in.length == 4) {
                    tmpparams.put(DataExpression.RAND_PDF, expr3);
                    Hop seed = processExpression(in[3], null, hops);
                    tmpparams.put(DataExpression.RAND_SEED, seed);
                } else if (in.length == 3) {
                    // check if the third argument is "replace" or "seed"
                    if (expr3.getValueType() == ValueType.BOOLEAN) {
                        tmpparams.put(DataExpression.RAND_PDF, expr3);
                        tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
                    } else if (expr3.getValueType() == ValueType.INT) {
                        tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
                        tmpparams.put(DataExpression.RAND_SEED, expr3);
                    } else
                        throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");
                } else if (in.length == 2) {
                    tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
                    tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
                }
                currBuiltinOp = new DataGenOp(DataGenMethod.SAMPLE, target, tmpparams);
                break;
            }
        case SOLVE:
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.SOLVE, expr, expr2);
            break;
        case INVERSE:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.INVERSE, expr);
            break;
        case CHOLESKY:
            currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.CHOLESKY, expr);
            break;
        case OUTER:
            if (!(expr3 instanceof LiteralOp))
                throw new HopsException("Operator for outer builtin function must be a constant: " + expr3);
            OpOp2 op = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp) expr3).getStringValue());
            if (op == null)
                throw new HopsException("Unsupported outer vector binary operation: " + ((LiteralOp) expr3).getStringValue());
            currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, expr, expr2);
            // flag op as specific outer vector operation
            ((BinaryOp) currBuiltinOp).setOuterVectorOperation(true);
            // force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
            currBuiltinOp.refreshSizeInformation();
            break;
        case CONV2D:
            {
                Hop image = expr;
                ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
                currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D, inHops1);
                setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
                break;
            }
        case BIAS_ADD:
            {
                ArrayList<Hop> inHops1 = new ArrayList<>();
                inHops1.add(expr);
                inHops1.add(expr2);
                currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.BIAS_ADD, inHops1);
                setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                break;
            }
        case BIAS_MULTIPLY:
            {
                ArrayList<Hop> inHops1 = new ArrayList<>();
                inHops1.add(expr);
                inHops1.add(expr2);
                currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.BIAS_MULTIPLY, inHops1);
                setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                break;
            }
        case AVG_POOL:
        case MAX_POOL:
            {
                Hop image = expr;
                ArrayList<Hop> inHops1 = getALHopsForPoolingForwardIM2COL(image, source, 1, hops);
                if (source.getOpCode() == BuiltinFunctionOp.MAX_POOL)
                    currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1);
                else
                    currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING, inHops1);
                setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
                break;
            }
        case AVG_POOL_BACKWARD:
        case MAX_POOL_BACKWARD:
            {
                Hop image = expr;
                // process dout as well
                ArrayList<Hop> inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops);
                if (source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD)
                    currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1);
                else
                    currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING_BACKWARD, inHops1);
                setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
                break;
            }
        case CONV2D_BACKWARD_FILTER:
            {
                Hop image = expr;
                ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
                currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, inHops1);
                setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
                break;
            }
        case CONV2D_BACKWARD_DATA:
            {
                Hop image = expr;
                ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
                currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA, inHops1);
                setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
                break;
            }
        default:
            throw new ParseException("Unsupported builtin function type: " + source.getOpCode());
    }
    boolean isConvolution = source.getOpCode() == BuiltinFunctionOp.CONV2D || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || source.getOpCode() == BuiltinFunctionOp.MAX_POOL || source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD || source.getOpCode() == BuiltinFunctionOp.AVG_POOL || source.getOpCode() == BuiltinFunctionOp.AVG_POOL_BACKWARD;
    if (!isConvolution) {
        // Since the dimension of output doesnot match that of input variable for these operations
        setIdentifierParams(currBuiltinOp, source.getOutput());
    }
    currBuiltinOp.setParseInfo(source);
    return currBuiltinOp;
}
Also used : OpOpN(org.apache.sysml.hops.Hop.OpOpN) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) OpOp2(org.apache.sysml.hops.Hop.OpOp2) DataGenOp(org.apache.sysml.hops.DataGenOp) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) HopsException(org.apache.sysml.hops.HopsException) TernaryOp(org.apache.sysml.hops.TernaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) ConvolutionOp(org.apache.sysml.hops.ConvolutionOp) NaryOp(org.apache.sysml.hops.NaryOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2)

Example 15 with TernaryOp

Use of org.apache.sysml.hops.TernaryOp in the Apache project systemml.

From the class TemplateCell, method isValidOperation.

/**
 * Decides whether the given hop is an operation accepted by this template.
 *
 * A hop qualifies only if its data type is MATRIX, it passes
 * {@code TemplateUtils.isOperationSupported}, and it is one of:
 * a unary op; a matrix-valued binary op with a scalar input, a
 * vector-or-scalar input next to a matrix input, or two equally sized
 * matrix inputs; a matrix-scalar-matrix ternary op over two vectors or over
 * two equally sized non-sparse inputs; a matrix-valued IFELSE ternary; or a
 * parameterized builtin REPLACE.
 *
 * @param hop the high-level operator to inspect
 * @return true if the hop satisfies the conditions above, false otherwise
 */
protected static boolean isValidOperation(Hop hop) {
    // indicators for matrix-valued binary operations
    boolean binaryWithScalar = false;
    boolean binaryWithVector = false;
    boolean binaryEqualSized = false;
    if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
        final Hop lhs = hop.getInput().get(0);
        final Hop rhs = hop.getInput().get(1);
        final DataType lhsType = lhs.getDataType();
        final DataType rhsType = rhs.getDataType();
        binaryWithScalar = lhsType.isScalar() || rhsType.isScalar();
        // the vector and equal-size cases are only accepted when the output dims are known
        if (hop.dimsKnown()) {
            final boolean matrixThenVector = lhsType.isMatrix() && TemplateUtils.isVectorOrScalar(rhs);
            binaryWithVector = matrixThenVector || (rhsType.isMatrix() && TemplateUtils.isVectorOrScalar(lhs));
            binaryEqualSized = HopRewriteUtils.isEqualSize(lhs, rhs) && lhsType.isMatrix() && rhsType.isMatrix();
        }
    }

    // indicators for ternary operations
    final boolean ternaryIfElse = HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix();
    boolean ternaryOverVectors = false;
    boolean ternaryDenseEqualSized = false;
    final boolean matrixScalarMatrixTernary = hop instanceof TernaryOp
        && hop.getInput().size() == 3
        && hop.dimsKnown()
        && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX);
    if (matrixScalarMatrixTernary) {
        final Hop first = hop.getInput().get(0);
        final Hop third = hop.getInput().get(2);
        ternaryOverVectors = TemplateUtils.isVector(first) && TemplateUtils.isVector(third);
        ternaryDenseEqualSized = HopRewriteUtils.isEqualSize(first, third)
            && !(HopRewriteUtils.isSparse(first) || HopRewriteUtils.isSparse(third));
    }

    // general preconditions: matrix output and an operation type the template supports
    if (hop.getDataType() != DataType.MATRIX || !TemplateUtils.isOperationSupported(hop)) {
        return false;
    }

    // accepted unary, binary, and ternary shapes
    if (hop instanceof UnaryOp
        || binaryWithScalar || binaryWithVector || binaryEqualSized
        || ternaryOverVectors || ternaryDenseEqualSized || ternaryIfElse) {
        return true;
    }

    // remaining accepted case: parameterized builtin replace
    return hop instanceof ParameterizedBuiltinOp
        && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.REPLACE;
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) DataType(org.apache.sysml.parser.Expression.DataType) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Aggregations

- TernaryOp (org.apache.sysml.hops.TernaryOp): 19
- AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 15
- Hop (org.apache.sysml.hops.Hop): 15
- LiteralOp (org.apache.sysml.hops.LiteralOp): 13
- ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 13
- AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 11
- BinaryOp (org.apache.sysml.hops.BinaryOp): 11
- UnaryOp (org.apache.sysml.hops.UnaryOp): 11
- IndexingOp (org.apache.sysml.hops.IndexingOp): 7
- ReorgOp (org.apache.sysml.hops.ReorgOp): 7
- HashMap (java.util.HashMap): 4
- DataOp (org.apache.sysml.hops.DataOp): 4
- CNode (org.apache.sysml.hops.codegen.cplan.CNode): 4
- CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 4
- CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 4
- CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 4
- TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType): 4
- CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 4
- MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 4
- ArrayList (java.util.ArrayList): 2