Search in sources:

Example 26 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationStatic, method simplifyBushyBinaryOperation.

/**
 * Rebalances a chain of one associative element-wise binary operation into a bushier tree
 * whenever one of the nested operands is a matrix multiplication:
 * <ul>
 *   <li>{@code (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)}</li>
 *   <li>{@code (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)}</li>
 *   <li>{@code (((Z%*%v)*X)*Y) -> (Z%*%v)*(X*Y)}</li>
 * </ul>
 *
 * <p>The rewrite is restricted to a matrix multiplication (ba) at both the leaf and the root
 * (the parent must be an {@link AggBinaryOp}), rather than to data at the leaf. This keeps it
 * from reorganizing too eagerly, which would lose additional rewrite potential. It has two
 * goals: (1) enable XtwXv, and (2) increase piggybacking potential by creating bushy trees.
 *
 * @param parent parent high-level operator; the rewrite only fires if it is an AggBinaryOp
 * @param hi high-level operator to inspect (input {@code pos} of {@code parent})
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the replacement operator if a rewrite was applied, otherwise {@code hi}
 */
private static Hop simplifyBushyBinaryOperation(Hop parent, Hop hi, int pos) {
    // root restriction: only binary operations consumed by a matrix multiplication qualify
    if (!(hi instanceof BinaryOp) || !(parent instanceof AggBinaryOp))
        return hi;

    BinaryOp outer = (BinaryOp) hi;
    Hop outerLeft = outer.getInput().get(0);
    Hop outerRight = outer.getInput().get(1);
    OpOp2 opcode = outer.getOp();

    boolean eligible = outerLeft.getDataType() == DataType.MATRIX
        && outerRight.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isValidOp(opcode, LOOKUP_VALID_ASSOCIATIVE_BINARY);
    if (!eligible)
        return hi;

    // case 1: (X op (Y op ba())) -> (X op Y) op ba()
    if (outerRight instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) outerRight;
        Hop y = inner.getInput().get(0);
        Hop matMult = inner.getInput().get(1);
        if (opcode == inner.getOp()
            && matMult.getDataType() == DataType.MATRIX
            && matMult instanceof AggBinaryOp) {
            BinaryOp xy = HopRewriteUtils.createBinary(outerLeft, y, opcode);
            BinaryOp rebalanced = HopRewriteUtils.createBinary(xy, matMult, opcode);
            HopRewriteUtils.replaceChildReference(parent, outer, rebalanced, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation1");
            return rebalanced;
        }
    }

    // case 2: ((ba() op X) op Y) -> ba() op (X op Y)
    if (outerLeft instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) outerLeft;
        Hop matMult = inner.getInput().get(0);
        Hop x = inner.getInput().get(1);
        // per dimension: X must not be a vector unless Y is a vector as well
        if (opcode == inner.getOp()
            && matMult.getDataType() == DataType.MATRIX
            && matMult instanceof AggBinaryOp
            && (x.getDim2() > 1 || outerRight.getDim2() == 1)
            && (x.getDim1() > 1 || outerRight.getDim1() == 1)) {
            BinaryOp xy = HopRewriteUtils.createBinary(x, outerRight, opcode);
            BinaryOp rebalanced = HopRewriteUtils.createBinary(matMult, xy, opcode);
            HopRewriteUtils.replaceChildReference(parent, outer, rebalanced, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation2");
            return rebalanced;
        }
    }
    return hi;
}
Also used: AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), OpOp2 (org.apache.sysml.hops.Hop.OpOp2), BinaryOp (org.apache.sysml.hops.BinaryOp)

Example 27 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationStatic, method canonicalizeMatrixMultScalarAdd.

/**
 * Canonicalizes scalar additions and subtractions around a matrix multiplication into the
 * single form {@code U%*%V + s}, so that subsequent rewrites (e.g., wdivmm or wcemm with
 * epsilon) only have to match one pattern:
 * <ul>
 *   <li>{@code eps + U%*%V} becomes {@code U%*%V + eps} (inputs swapped)</li>
 *   <li>{@code U%*%V - eps} becomes {@code U%*%V + (-eps)} (opcode set to PLUS and the
 *       scalar input replaced by its negation)</li>
 * </ul>
 * The given operator is modified in place; no new root operator is created.
 *
 * @param hi high-level operator to inspect
 * @return {@code hi} itself, possibly with rewired inputs or a changed opcode
 */
private static Hop canonicalizeMatrixMultScalarAdd(Hop hi) {
    // only binary operations (+ or -) of a matrix mult and a scalar are of interest
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp binary = (BinaryOp) hi;
    Hop first = hi.getInput().get(0);
    Hop second = hi.getInput().get(1);

    if (first.getDataType().isScalar() && second instanceof AggBinaryOp
        && binary.getOp() == OpOp2.PLUS) {
        // (eps + U%*%V) -> (U%*%V + eps)
        HopRewriteUtils.removeAllChildReferences(binary);
        HopRewriteUtils.addChildReference(binary, second, 0);
        HopRewriteUtils.addChildReference(binary, first, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line " + hi.getBeginLine() + ").");
    } else if (second.getDataType().isScalar() && first instanceof AggBinaryOp
        && binary.getOp() == OpOp2.MINUS) {
        // (U%*%V - eps) -> (U%*%V + (-eps))
        binary.setOp(OpOp2.PLUS);
        Hop negated = HopRewriteUtils.createBinaryMinus(second);
        HopRewriteUtils.replaceChildReference(binary, second, negated, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line " + hi.getBeginLine() + ").");
    }
    return hi;
}
Also used: AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), BinaryOp (org.apache.sysml.hops.BinaryOp)

Example 28 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

In class Explain, method getNodeLabel.

/**
 * Builds the textual label of a hop. The label consists of:
 * <ul>
 *   <li>the hop's op string,</li>
 *   <li>for an {@link AggBinaryOp} with a selected method, the matrix-mult method name,</li>
 *   <li>if {@code SHOW_DATA_FLOW_PROPERTIES} is set, the reblock/checkpoint flags,</li>
 *   <li>the source position as {@code [file beginLine:beginCol-endLine:endCol]} (the file
 *       part is omitted when the hop has no filename), and</li>
 *   <li>the lower-cased update type if it is in-place.</li>
 * </ul>
 *
 * @param hop high-level operator to describe
 * @return the label text
 */
private static String getNodeLabel(Hop hop) {
    StringBuilder sb = new StringBuilder();
    sb.append(hop.getOpString());
    // matrix multiplications additionally show the selected mmult method, if any
    if (hop instanceof AggBinaryOp) {
        AggBinaryOp aggBinOp = (AggBinaryOp) hop;
        if (aggBinOp.getMMultMethod() != null)
            sb.append(' ').append(aggBinOp.getMMultMethod().name()).append(' ');
    }
    // data flow properties
    if (SHOW_DATA_FLOW_PROPERTIES) {
        boolean reblock = hop.requiresReblock();
        boolean checkpoint = hop.requiresCheckpoint();
        if (reblock && checkpoint)
            sb.append(", rblk,chkpt");
        else if (reblock)
            sb.append(", rblk");
        else if (checkpoint)
            sb.append(", chkpt");
    }
    // source position; the filename prefix is only emitted when known
    sb.append('[');
    if (hop.getFilename() != null)
        sb.append(hop.getFilename()).append(' ');
    sb.append(hop.getBeginLine()).append(':').append(hop.getBeginColumn())
        .append('-').append(hop.getEndLine()).append(':').append(hop.getEndColumn())
        .append(']');
    if (hop.getUpdateType().isInPlace()) {
        // Locale.ROOT: default-locale lowercasing breaks under e.g. Turkish ("I" -> dotless i)
        sb.append(',').append(hop.getUpdateType().toString().toLowerCase(java.util.Locale.ROOT));
    }
    return sb.toString();
}
Also used: AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)

Example 29 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

In class PlanSelectionFuseCostBased, method rGetComputeCosts.

/**
 * Recursively assigns a relative compute-cost weight to every hop of the DAG rooted at
 * {@code current} and records it in {@code computeCosts}, keyed by hop ID. Children are
 * costed before their parent. Hops that already have an entry are skipped, so shared
 * sub-DAGs are visited only once.
 *
 * <p>Hop types without an explicit entry get weight 1. Unknown opcodes of unary, binary,
 * ternary and aggregate-unary hops also get weight 1 and additionally log a warning.
 *
 * @param current hop to cost (together with its inputs)
 * @param partition forwarded to the recursive calls but not otherwise read here
 * @param computeCosts output map from hop ID to cost weight; also serves as the visited set
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    if (computeCosts.containsKey(current.getHopID()))
        return;
    //recursively process children
    for (Hop c : current.getInput())
        rGetComputeCosts(c, partition, computeCosts);
    //get costs for given hop (default weight 1)
    double costs = 1;
    if (current instanceof UnaryOp)
        costs = getUnaryOpComputeCosts((UnaryOp) current);
    else if (current instanceof BinaryOp)
        costs = getBinaryOpComputeCosts((BinaryOp) current);
    else if (current instanceof TernaryOp)
        costs = getTernaryOpComputeCosts((TernaryOp) current);
    else if (current instanceof ParameterizedBuiltinOp || current instanceof IndexingOp
        || current instanceof ReorgOp)
        costs = 1;
    else if (current instanceof AggBinaryOp)
        costs = 2; //matrix vector
    else if (current instanceof AggUnaryOp)
        costs = getAggUnaryOpComputeCosts((AggUnaryOp) current);
    computeCosts.put(current.getHopID(), costs);
}

/** Returns the cost weight of a unary operation (1, with a warning, if not modeled). */
private static double getUnaryOpComputeCosts(UnaryOp hop) {
    switch (hop.getOp()) {
        case ABS:
        case ROUND:
        case CEIL:
        case FLOOR:
        case SIGN:
        case SELP:
            return 1;
        case SPROP:
        case SQRT:
            return 2;
        case EXP:
            return 18;
        case SIGMOID:
            return 21;
        case LOG:
        case LOG_NZ:
            return 32;
        case NCOL:
        case NROW:
        case PRINT:
        case CAST_AS_BOOLEAN:
        case CAST_AS_DOUBLE:
        case CAST_AS_INT:
        case CAST_AS_MATRIX:
        case CAST_AS_SCALAR:
            return 1;
        case SIN:
            return 18;
        case COS:
            return 22;
        case TAN:
            return 42;
        case ASIN:
            return 93;
        case ACOS:
            return 103;
        case ATAN:
            return 40;
        case CUMSUM:
        case CUMMIN:
        case CUMMAX:
        case CUMPROD:
            return 1;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + hop.getOp());
            return 1;
    }
}

/** Returns the cost weight of a binary operation (1, with a warning, if not modeled). */
private static double getBinaryOpComputeCosts(BinaryOp hop) {
    switch (hop.getOp()) {
        case MULT:
        case PLUS:
        case MINUS:
        case MIN:
        case MAX:
        case AND:
        case OR:
        case EQUAL:
        case NOTEQUAL:
        case LESS:
        case LESSEQUAL:
        case GREATER:
        case GREATEREQUAL:
        case CBIND:
        case RBIND:
            return 1;
        case INTDIV:
            return 6;
        case MODULUS:
            return 8;
        case DIV:
            return 22;
        case LOG:
        case LOG_NZ:
            return 32;
        case POW:
            // X^2 gets the same weight as MULT; other exponents are costed as a real pow
            return HopRewriteUtils.isLiteralOfValue(hop.getInput().get(1), 2) ? 1 : 16;
        case MINUS_NZ:
        case MINUS1_MULT:
            return 2;
        case CENTRALMOMENT:
            // unweighted cm(X, order): the order is the second input
            return getCentralMomentComputeCosts(hop.getInput().get(1), false);
        case COVARIANCE:
            return 23;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + hop.getOp());
            return 1;
    }
}

/** Returns the cost weight of a ternary operation (1, with a warning, if not modeled). */
private static double getTernaryOpComputeCosts(TernaryOp hop) {
    switch (hop.getOp()) {
        case PLUS_MULT:
        case MINUS_MULT:
            return 2;
        case CTABLE:
            return 3;
        case CENTRALMOMENT:
            // weighted cm(X, W, order): the order is the third input, input 1 holds the
            // weights W. NOTE(review): assumes the (X, W, order) input layout built for
            // cm() by the DML translator -- confirm
            return getCentralMomentComputeCosts(hop.getInput().get(2), true);
        case COVARIANCE:
            return 23;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + hop.getOp());
            return 1;
    }
}

/**
 * Returns the cost weight of a central moment computation. A non-literal order is costed
 * like order 2. A literal order outside 0..5 gets weight 1. For known orders, the weighted
 * variant costs one more than the unweighted one.
 *
 * @param order the hop providing the moment order
 * @param weighted whether this is the weighted (ternary) variant
 * @return cost weight
 */
private static double getCentralMomentComputeCosts(Hop order, boolean weighted) {
    int type = (int) (order instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) order) : 2);
    double costs;
    switch (type) {
        case 0: //count
            costs = 1;
            break;
        case 1: //mean
            costs = 8;
            break;
        case 2: //cm2
            costs = 16;
            break;
        case 3: //cm3
            costs = 31;
            break;
        case 4: //cm4
            costs = 51;
            break;
        case 5: //variance
            costs = 16;
            break;
        default:
            return 1;
    }
    return weighted ? costs + 1 : costs;
}

/** Returns the cost weight of an aggregate unary operation (1, with a warning, if not modeled). */
private static double getAggUnaryOpComputeCosts(AggUnaryOp hop) {
    switch (hop.getOp()) {
        case SUM:
            return 4;
        case SUM_SQ:
            return 5;
        case MIN:
        case MAX:
            return 1;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + hop.getOp());
            return 1;
    }
}
Also used: ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp), AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), UnaryOp (org.apache.sysml.hops.UnaryOp), IndexingOp (org.apache.sysml.hops.IndexingOp), AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), ReorgOp (org.apache.sysml.hops.ReorgOp), LiteralOp (org.apache.sysml.hops.LiteralOp), BinaryOp (org.apache.sysml.hops.BinaryOp), TernaryOp (org.apache.sysml.hops.TernaryOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 29; Hop (org.apache.sysml.hops.Hop): 24; AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 17; BinaryOp (org.apache.sysml.hops.BinaryOp): 14; LiteralOp (org.apache.sysml.hops.LiteralOp): 11; ReorgOp (org.apache.sysml.hops.ReorgOp): 8; UnaryOp (org.apache.sysml.hops.UnaryOp): 8; TernaryOp (org.apache.sysml.hops.TernaryOp): 6; MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 6; IndexingOp (org.apache.sysml.hops.IndexingOp): 5; ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 5; QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 5; CNode (org.apache.sysml.hops.codegen.cplan.CNode): 4; CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 3; CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 3; CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 3; DataOp (org.apache.sysml.hops.DataOp): 2; OpOp2 (org.apache.sysml.hops.Hop.OpOp2): 2; CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 2; TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType): 2