Search in sources:

Example 96 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationStatic, method simplifyBinaryToUnaryOperation.

/**
 * Simplifies binary operations into cheaper forms. This relies on a previous common
 * subexpression elimination, so that identical operands are represented by the same hop.
 * At the same time this serves as a canonicalization for more complex rewrites.
 *
 * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X)
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the high-level operator that now takes the place of {@code hi}
 *         ({@code hi} itself if it was rewritten in place or not rewritten at all)
 */
private static Hop simplifyBinaryToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop first = hi.getInput().get(0);
    Hop second = hi.getInput().get(1);

    if (first == second && first.getDataType() == DataType.MATRIX) {
        // patterns: X+X -> X*2, X*X -> X^2, rewritten in place on hi by swapping the
        // operator and replacing the second input with the literal 2;
        // note: specific LOPs are later compiled for X*2 and X^2
        if (bop.getOp() == OpOp2.PLUS) {
            bop.setOp(OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(hi, second, new LiteralOp(2), 1);
            LOG.debug("Applied simplifyBinaryToUnaryOperation1 (line " + hi.getBeginLine() + ").");
        }
        else if (bop.getOp() == OpOp2.MULT) {
            bop.setOp(OpOp2.POW);
            HopRewriteUtils.replaceChildReference(hi, second, new LiteralOp(2), 1);
            LOG.debug("Applied simplifyBinaryToUnaryOperation2 (line " + hi.getBeginLine() + ").");
        }
        return hi;
    }

    // pattern: (X>0)-(X<0) -> sign(X); both comparisons must share the same input X
    // and compare against the literal 0 (checked in the original short-circuit order)
    boolean isSignPattern = bop.getOp() == OpOp2.MINUS
        && HopRewriteUtils.isBinary(first, OpOp2.GREATER)
        && HopRewriteUtils.isBinary(second, OpOp2.LESS)
        && first.getInput().get(0) == second.getInput().get(0)
        && first.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) first.getInput().get(1)) == 0
        && second.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) second.getInput().get(1)) == 0;
    if (!isSignPattern)
        return hi;

    UnaryOp sign = HopRewriteUtils.createUnary(first.getInput().get(0), OpOp1.SIGN);
    HopRewriteUtils.replaceChildReference(parent, hi, sign, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, first, second);
    LOG.debug("Applied simplifyBinaryToUnaryOperation3 (line " + sign.getBeginLine() + ").");
    return sign;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 97 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationStatic, method simplifyDistributiveBinaryOperation.

/**
 * Factors out a shared matrix operand of a distributive binary operation:
 *
 * (X-Y*X) -> (1-Y)*X,    (Y*X-X) -> (Y-1)*X
 * (X+Y*X) -> (1+Y)*X,    (Y*X+X) -> (Y+1)*X
 *
 * The inner multiplication may appear in either order (Y*X or X*Y). At most one
 * rewrite is applied; the left-hand product is tried first.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new (… )*X operator if a rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyDistributiveBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    boolean eligible = left.getDataType() == DataType.MATRIX
        && right.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY);
    if (!eligible)
        return hi;

    // (Y*X-X) -> (Y-1)*X and (Y*X+X) -> (Y+1)*X: product on the left
    if (HopRewriteUtils.isBinary(left, OpOp2.MULT)) {
        Hop m1 = left.getInput().get(0);
        Hop m2 = left.getInput().get(1);
        boolean matches = m1.getDataType() == DataType.MATRIX
            && m2.getDataType() == DataType.MATRIX
            && (right == m1 || right == m2)
            && m1 != m2;
        if (matches) {
            Hop factor = (right == m1) ? m2 : m1;
            BinaryOp inner = HopRewriteUtils.createBinary(factor, new LiteralOp(1), bop.getOp());
            BinaryOp outer = HopRewriteUtils.createBinary(inner, right, OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(parent, hi, outer, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, left);
            LOG.debug("Applied simplifyDistributiveBinaryOperation1");
            return outer;
        }
    }

    // (X-Y*X) -> (1-Y)*X and (X+Y*X) -> (1+Y)*X: product on the right
    if (HopRewriteUtils.isBinary(right, OpOp2.MULT)) {
        Hop m1 = right.getInput().get(0);
        Hop m2 = right.getInput().get(1);
        boolean matches = m1.getDataType() == DataType.MATRIX
            && m2.getDataType() == DataType.MATRIX
            && (left == m1 || left == m2)
            && m1 != m2;
        if (matches) {
            Hop factor = (left == m1) ? m2 : m1;
            BinaryOp inner = HopRewriteUtils.createBinary(new LiteralOp(1), factor, bop.getOp());
            BinaryOp outer = HopRewriteUtils.createBinary(inner, left, OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(parent, hi, outer, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, right);
            LOG.debug("Applied simplifyDistributiveBinaryOperation2");
            return outer;
        }
    }

    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 98 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationStatic, method simplifyTransposedAppend.

/**
 * Removes a transpose on top of an append whose inputs are both transposed, e.g.,
 * t(cbind(t(A),t(B))) --> rbind(A,B) and t(rbind(t(A),t(B))) --> cbind(A,B).
 *
 * Requires that the append is consumed only by the root transpose, and that each
 * input transpose passes {@code HopRewriteUtils.isTransposeOperation(hop, 1)}
 * (single consumer).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate root transpose)
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new append operator if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyTransposedAppend(Hop parent, Hop hi, int pos) {
    // the pattern must be rooted at a transpose
    if (!HopRewriteUtils.isTransposeOperation(hi))
        return hi;

    // its input must be a cbind/rbind that has no other consumer
    Hop append = hi.getInput().get(0);
    boolean isAppend = append instanceof BinaryOp
        && (((BinaryOp) append).getOp() == OpOp2.CBIND
            || ((BinaryOp) append).getOp() == OpOp2.RBIND);
    if (!isAppend || append.getParent().size() != 1)
        return hi;

    // both append inputs must be transposes, each consumed only by the append
    BinaryOp bop = (BinaryOp) append;
    Hop tLeft = bop.getInput().get(0);
    Hop tRight = bop.getInput().get(1);
    if (!HopRewriteUtils.isTransposeOperation(tLeft, 1)
        || !HopRewriteUtils.isTransposeOperation(tRight, 1))
        return hi;

    // build a new subdag instead of updating in place, to prevent anomalies with
    // multiple consumers during the rewrite process; cbind and rbind swap roles
    OpOp2 swapped = (bop.getOp() == OpOp2.CBIND) ? OpOp2.RBIND : OpOp2.CBIND;
    BinaryOp bopnew = HopRewriteUtils.createBinary(
        tLeft.getInput().get(0), tRight.getInput().get(0), swapped);
    HopRewriteUtils.replaceChildReference(parent, hi, bopnew, pos);
    LOG.debug("Applied simplifyTransposedAppend (line " + bopnew.getBeginLine() + ").");
    return bopnew;
}
Also used : Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 99 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryVectorizeOperation.

/**
 * Removes an unnecessarily vectorized scalar input from a binary matrix operation that
 * supports matrix-scalar operations. A "vectorized scalar" here is a {@code DataGenOp}
 * with method RAND and a constant value; it is replaced by its RAND_MIN input.
 *
 * Outer operations (column vector op row vector) are left untouched. The rewrite is
 * applied to at most one side (the right side is checked first), and {@code hi} is
 * modified in place.
 *
 * @param hi high-level operator
 * @return {@code hi}
 */
private static Hop removeUnnecessaryVectorizeOperation(Hop hi) {
    // applies to all binary matrix operations, if one input is unnecessarily vectorized
    if (!(hi instanceof BinaryOp)
        || hi.getDataType() != DataType.MATRIX
        || !((BinaryOp) hi).supportsMatrixScalarOperations())
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    // skip outer operations: column vector (left) op row vector (right)
    boolean isOuter = left.getDim1() > 1 && left.getDim2() == 1
        && right.getDim1() == 1 && right.getDim2() > 1;
    if (isOuter)
        return hi;

    // Note: we apply this rewrite to at most one side in order to keep the
    // output semantically equivalent. However, future extensions might consider
    // to remove vectors from both sides, compute the binary op on scalars and
    // finally feed it into a datagenop of the original dimensions.
    if (left.getDataType() == DataType.MATRIX && right instanceof DataGenOp) {
        // right input is a candidate vectorized scalar
        DataGenOp gen = (DataGenOp) right;
        if (gen.getOp() == DataGenMethod.RAND && gen.hasConstantValue()) {
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 1);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation1");
        }
    }
    else if (right.getDataType() == DataType.MATRIX && left instanceof DataGenOp) {
        // left input is a candidate vectorized scalar; additionally require that the
        // left shape does not exceed the right shape in either dimension
        DataGenOp gen = (DataGenOp) left;
        if (gen.getOp() == DataGenMethod.RAND
            && gen.hasConstantValue()
            && (left.getDim2() == 1 || right.getDim2() > 1)
            && (left.getDim1() == 1 || right.getDim1() > 1)) {
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 0);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation2");
        }
    }
    return hi;
}
Also used : DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 100 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationStatic, method fuseBinarySubDAGToUnaryOperation.

/**
 * Fuses more complex binary sub-DAGs into a single cheaper operation:
 *
 * X*(1-X) -> sprop(X)
 * (1-X)*X -> sprop(X)
 * 1/(1+exp(-X)) -> sigmoid(X)
 * 1/(1+exp(X)) -> sigmoid(-X)
 * (X>0)*X -> max(X,0)
 * X*(X>0) -> max(X,0)
 *
 * At most one of these rewrites is applied per invocation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the operator that replaces {@code hi}, or {@code hi} itself if no rewrite applied
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;
        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (1-X)*X -> sprop(X); here 'left' is the (1-X) intermediate and 'right' is X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            // X*(1-X) -> sprop(X); here 'right' is the (1-X) intermediate and 'left' is X
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // clean up the (1-X) intermediate ('right'), not the shared input X ('left'),
                    // which is now consumed by the new sprop operator
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }
        // sigmoid operator: 1/(1+exp(...))
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: 1/(1+exp(0-X)) -> sigmoid(X); any other minus is not rewritten
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0)
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                    } else {
                        // Pattern 2: 1/(1+exp(X)) -> sigmoid(-X), e.g., where -(-X) has been removed by
                        // the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        // NOTE(review): in pattern 1 the (0-X) intermediate (uopin) is not passed
                        // to cleanupUnreferenced here — confirm whether that is intended
                        HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }
        // max(X,0) operator (note: same initial pattern as sprop); replacing X*tmp with
        // tmp=(X>0) by max(X,0) is beneficial due to lower memory requirements and
        // simpler sparsity propagation
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (X>0)*X -> max(X,0); here 'left' is the (X>0) intermediate and 'right' is X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            // X*(X>0) -> max(X,0); here 'right' is the (X>0) intermediate and 'left' is X
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // clean up the (X>0) intermediate ('right'), not the shared input X ('left'),
                    // which is now consumed by the new max operator
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

Hop (org.apache.sysml.hops.Hop)307 LiteralOp (org.apache.sysml.hops.LiteralOp)94 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)65 BinaryOp (org.apache.sysml.hops.BinaryOp)63 ArrayList (java.util.ArrayList)61 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)61 HashMap (java.util.HashMap)44 DataOp (org.apache.sysml.hops.DataOp)41 UnaryOp (org.apache.sysml.hops.UnaryOp)41 HashSet (java.util.HashSet)39 ReorgOp (org.apache.sysml.hops.ReorgOp)32 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)28 StatementBlock (org.apache.sysml.parser.StatementBlock)28 IndexingOp (org.apache.sysml.hops.IndexingOp)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)23 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)23 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 DataGenOp (org.apache.sysml.hops.DataGenOp)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HopsException (org.apache.sysml.hops.HopsException)18