Search in sources :

Example 16 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationDynamic, method simplifyDotProductSum.

/**
 * Rewrites a full sum over a single-column input into a dot product computed by
 * matrix multiplication, without materializing the element-wise intermediate:
 * sum(v^2) -> cast_as_scalar(t(v) %*% v) and sum(v1*v2) -> cast_as_scalar(t(v1) %*% v2).
 * <p>
 * NOTE: the original design note restricted this rewrite to sum(a^2) (and transitively
 * sum(a*a)), because a general mm a%*%b on MR can be counter-productive (e.g., MMCJ)
 * while tsmm is always beneficial. The code below additionally rewrites sum(v1*v2),
 * but only when both operands have a single column. Nested products such as
 * sum(v1*v2*v3) are left alone so they can later be compiled into a ta+* lop.
 * If ALLOW_SUM_PRODUCT_REWRITES is set, products with a squared operand are also
 * left alone so that tak+* can handle them.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the replacement operator if the rewrite was applied, otherwise hi
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // only full (row and column) sum aggregates over a single-column input qualify (vector, for correctness)
    boolean fullSumOverVector = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && hi.getInput().get(0).getDim2() == 1;
    if (!fullSumOverVector) {
        return hi;
    }

    Hop sumInput = hi.getInput().get(0);
    Hop lhs = null;
    Hop rhs = null;
    if (isSingleConsumerSquare(sumInput)) {
        // sum(v^2), which might have been rewritten from sum(v*v)
        lhs = sumInput.getInput().get(0);
        rhs = lhs;
    } else if (isDotProductCandidate(sumInput)) {
        // sum(v1*v2)
        lhs = sumInput.getInput().get(0);
        rhs = sumInput.getInput().get(1);
    }
    if (lhs == null || rhs == null) {
        return hi;
    }

    // build cast_as_scalar(t(lhs) %*% rhs) and rehang it under the parent node
    ReorgOp transposed = HopRewriteUtils.createTranspose(lhs);
    AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(transposed, rhs);
    UnaryOp asScalar = HopRewriteUtils.createUnary(matMult, OpOp1.CAST_AS_SCALAR);
    HopRewriteUtils.replaceChildReference(parent, hi, asScalar, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, sumInput);
    LOG.debug("Applied simplifyDotProductSum.");
    return asScalar;
}

/**
 * Returns true for h = v^2 (literal exponent compared as a double) whose only consumer
 * is the enclosing sum.
 */
private static boolean isSingleConsumerSquare(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.POW)
        && h.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) h.getInput().get(1)) == 2
        && h.getParent().size() == 1;
}

/**
 * Returns true for h = v1*v2 with no other consumer than the sum, where both operands have
 * a single column and neither operand is itself a product. If ALLOW_SUM_PRODUCT_REWRITES is
 * set, neither operand may be a square, so that tak+* can handle it.
 */
private static boolean isDotProductCandidate(Hop h) {
    if (!HopRewriteUtils.isBinary(h, OpOp2.MULT, 1)) {
        return false;
    }
    Hop first = h.getInput().get(0);
    Hop second = h.getInput().get(1);
    return first.getDim2() == 1
        && second.getDim2() == 1
        && !HopRewriteUtils.isBinary(first, OpOp2.MULT)
        && !HopRewriteUtils.isBinary(second, OpOp2.MULT)
        && !(ALLOW_SUM_PRODUCT_REWRITES && isSquareByLongLiteral(first))
        && !(ALLOW_SUM_PRODUCT_REWRITES && isSquareByLongLiteral(second));
}

/** Returns true for h = A^2, with the literal exponent compared via its long value. */
private static boolean isSquareByLongLiteral(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.POW)
        && h.getInput().get(1) instanceof LiteralOp
        && ((LiteralOp) h.getInput().get(1)).getLongValue() == 2;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 17 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationDynamic, method simplifyNrowNcolComputation.

/**
 * Replaces nrow(X) and ncol(X) with a literal holding the corresponding dimension of X
 * whenever that dimension is known. This applies even if the intermediate X is otherwise
 * not required, e.g., when it is part of a fused operator.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the new literal if the rewrite was applied, otherwise hi
 */
private static Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp)) {
        return hi;
    }
    OpOp1 kind = ((UnaryOp) hi).getOp();
    if (kind == OpOp1.NROW && hi.getInput().get(0).rowsKnown()) {
        return replaceWithDimensionLiteral(parent, hi, pos,
            hi.getInput().get(0).getDim1(), "Applied simplifyNrowComputation nrow(");
    }
    if (kind == OpOp1.NCOL && hi.getInput().get(0).colsKnown()) {
        return replaceWithDimensionLiteral(parent, hi, pos,
            hi.getInput().get(0).getDim2(), "Applied simplifyNcolComputation ncol(");
    }
    return hi;
}

/**
 * Swaps hi under its parent for a literal holding the given dimension, detaches hi,
 * and logs the rewrite using the given message prefix.
 */
private static Hop replaceWithDimensionLiteral(Hop parent, Hop hi, int pos, long dim, String logPrefix) {
    Hop literal = new LiteralOp(dim);
    HopRewriteUtils.replaceChildReference(parent, hi, literal, pos, false);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug(logPrefix + hi.getHopID() + ") -> " + literal.getName() + " (line " + hi.getBeginLine() + ").");
    return literal;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 18 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationStatic, method simplifyUnaryPPredOperation.

/**
 * Removes abs, sign, ceil, floor or round when it is applied to a matrix-valued ppred
 * (comparison) binary operation; the ppred itself takes the unary operation's place.
 * NOTE(review): this relies on ppred results being 0/1 values, which these operations
 * leave unchanged.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the ppred input if the rewrite was applied, otherwise hi
 */
private static Hop simplifyUnaryPPredOperation(Hop parent, Hop hi, int pos) {
    boolean unaryOverPPred = hi instanceof UnaryOp
        && hi.getDataType() == DataType.MATRIX
        && hi.getInput().get(0) instanceof BinaryOp
        && ((BinaryOp) hi.getInput().get(0)).isPPredOperation();
    if (!unaryOverPPred) {
        return hi;
    }

    OpOp1 op = ((UnaryOp) hi).getOp();
    boolean removable = op == OpOp1.ABS
        || op == OpOp1.SIGN
        || op == OpOp1.CEIL
        || op == OpOp1.FLOOR
        || op == OpOp1.ROUND;
    if (!removable) {
        return hi;
    }

    // clear the unary-binary link by hanging the ppred directly under the parent
    Hop ppred = hi.getInput().get(0);
    HopRewriteUtils.replaceChildReference(parent, hi, ppred, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied simplifyUnaryPPredOperation.");
    return ppred;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 19 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationStatic, method simplifyBinaryToUnaryOperation.

/**
 * Simplifies binary operations whose operands are identical or complementary. This relies on
 * previous common subexpression elimination, so that identical operands are the same hop.
 * At the same time, it serves as a canonicalization for more complex rewrites.
 *
 * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X)
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return hi (possibly modified in place), or the new sign operator
 */
private static Hop simplifyBinaryToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    if (lhs == rhs && lhs.getDataType() == DataType.MATRIX) {
        // X+X -> X*2 and X*X -> X^2 are kept as binary ops (modified in place),
        // because specific LOPS are compiled for X*2 and X^2 later on
        OpOp2 op = bop.getOp();
        if (op == OpOp2.PLUS) {
            bop.setOp(OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(hi, rhs, new LiteralOp(2), 1);
            LOG.debug("Applied simplifyBinaryToUnaryOperation1 (line " + hi.getBeginLine() + ").");
        } else if (op == OpOp2.MULT) {
            bop.setOp(OpOp2.POW);
            HopRewriteUtils.replaceChildReference(hi, rhs, new LiteralOp(2), 1);
            LOG.debug("Applied simplifyBinaryToUnaryOperation2 (line " + hi.getBeginLine() + ").");
        }
        return hi;
    }

    // (X>0)-(X<0) -> sign(X); conditions are checked in this order on purpose (short-circuit)
    boolean signPattern = bop.getOp() == OpOp2.MINUS
        && HopRewriteUtils.isBinary(lhs, OpOp2.GREATER)
        && HopRewriteUtils.isBinary(rhs, OpOp2.LESS)
        && lhs.getInput().get(0) == rhs.getInput().get(0)
        && lhs.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) lhs.getInput().get(1)) == 0
        && rhs.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) rhs.getInput().get(1)) == 0;
    if (!signPattern) {
        return hi;
    }
    UnaryOp sign = HopRewriteUtils.createUnary(lhs.getInput().get(0), OpOp1.SIGN);
    HopRewriteUtils.replaceChildReference(parent, hi, sign, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, lhs, rhs);
    LOG.debug("Applied simplifyBinaryToUnaryOperation3 (line " + sign.getBeginLine() + ").");
    return sign;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 20 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationStatic, method fuseBinarySubDAGToUnaryOperation.

/**
 * Handles the simplification of more complex binary sub-DAGs into a single operation.
 *
 * X*(1-X) -> sprop(X)
 * (1-X)*X -> sprop(X)
 * 1/(1+exp(-X)) -> sigmoid(X)
 * 1/(1+exp(X)) -> sigmoid(-X)
 * (X>0)*X -> max(X,0)
 * X*(X>0) -> max(X,0)
 *
 * After a rewrite, the replaced operator and the now-superseded intermediates of the matched
 * pattern are passed to HopRewriteUtils.cleanupUnreferenced. Intermediates that still have
 * other consumers are left in place for them.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the replacement operator if a rewrite was applied, otherwise hi
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;
        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (1-X)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // left is the (1-X) intermediate; X (right) is now consumed by sprop
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            // X*(1-X)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // right is the (1-X) intermediate here (not left, which is X and is
                    // still consumed by sprop); passing left would leave (1-X) dangling
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }
        // sigmoid operator
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: 1/(1+exp(0-X)) -> sigmoid(X)
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0)
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                    } else {
                        // Pattern 2: 1/(1+exp(X)), e.g., where -(-X) has been removed by
                        // the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        // uopin is included so that an unused (0-X) from pattern 1 is cleaned up
                        // as well; in pattern 2 it is consumed by the new minus and thus kept
                        HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop, uopin);
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }
        // select positive (selp) pattern, rewritten to max(X,0) (note: same initial pattern as sprop)
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (X>0)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // left is the (X>0) intermediate; X (right) is now consumed by max
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            // X*(X>0)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // right is the (X>0) intermediate here (not left, which is X and is
                    // still consumed by max); passing left would leave (X>0) dangling
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

UnaryOp (org.apache.sysml.hops.UnaryOp)42 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)39 Hop (org.apache.sysml.hops.Hop)35 LiteralOp (org.apache.sysml.hops.LiteralOp)24 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)18 BinaryOp (org.apache.sysml.hops.BinaryOp)18 DataOp (org.apache.sysml.hops.DataOp)9 IndexingOp (org.apache.sysml.hops.IndexingOp)8 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 HashMap (java.util.HashMap)5 ReorgOp (org.apache.sysml.hops.ReorgOp)5 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HopsException (org.apache.sysml.hops.HopsException)4 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)4 MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject)4 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3