Search in sources :

Example 36 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

Used in class RewriteAlgebraicSimplificationStatic, method simplifyDistributiveBinaryOperation.

/**
 * Factors a shared matrix operand out of a distributive +/- expression:
 *
 * (Y*X-X) -> (Y-1)*X,    (Y*X+X) -> (Y+1)*X
 * (X-Y*X) -> (1-Y)*X,    (X+Y*X) -> (1+Y)*X
 *
 * Both operands of hi must be matrices and the operator of hi must be listed in
 * LOOKUP_VALID_DISTRIBUTIVE_BINARY. The product Y*X may appear in either operand
 * order, but its two inputs must be distinct matrix hops. The left-product form
 * is tried first; at most one rewrite is applied.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the parent's inputs (as passed to replaceChildReference)
 * @return the new (factor)*X operator if a rewrite was applied, otherwise hi
 */
private static Hop simplifyDistributiveBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);

    boolean bothMatrices = lhs.getDataType() == DataType.MATRIX && rhs.getDataType() == DataType.MATRIX;
    if (!bothMatrices || !HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY))
        return hi;

    // left operand is the product: (Y*X-X) -> (Y-1)*X, (Y*X+X) -> (Y+1)*X
    if (HopRewriteUtils.isBinary(lhs, OpOp2.MULT)) {
        Hop a = lhs.getInput().get(0);
        Hop b = lhs.getInput().get(1);
        boolean matrixFactors = a.getDataType() == DataType.MATRIX && b.getDataType() == DataType.MATRIX;
        if (matrixFactors && a != b && (rhs == a || rhs == b)) {
            // X is whichever factor equals the right operand; Y is the other one
            Hop x = rhs;
            Hop y = (rhs == a) ? b : a;
            BinaryOp factor = HopRewriteUtils.createBinary(y, new LiteralOp(1), bop.getOp());
            BinaryOp product = HopRewriteUtils.createBinary(factor, x, OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, lhs);
            LOG.debug("Applied simplifyDistributiveBinaryOperation1");
            return product;
        }
    }

    // right operand is the product: (X-Y*X) -> (1-Y)*X, (X+Y*X) -> (1+Y)*X
    if (HopRewriteUtils.isBinary(rhs, OpOp2.MULT)) {
        Hop a = rhs.getInput().get(0);
        Hop b = rhs.getInput().get(1);
        boolean matrixFactors = a.getDataType() == DataType.MATRIX && b.getDataType() == DataType.MATRIX;
        if (matrixFactors && a != b && (lhs == a || lhs == b)) {
            // X is whichever factor equals the left operand; Y is the other one
            Hop x = lhs;
            Hop y = (lhs == a) ? b : a;
            BinaryOp factor = HopRewriteUtils.createBinary(new LiteralOp(1), y, bop.getOp());
            BinaryOp product = HopRewriteUtils.createBinary(factor, x, OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, rhs);
            LOG.debug("Applied simplifyDistributiveBinaryOperation2");
            return product;
        }
    }

    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 37 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

Used in class RewriteAlgebraicSimplificationStatic, method fuseBinarySubDAGToUnaryOperation.

/**
 * Fuses small binary sub-DAGs into a single unary (or simpler binary) operation:
 *
 * X*(1-X) -> sprop(X)
 * (1-X)*X -> sprop(X)
 * 1/(1+exp(-X)) -> sigmoid(X)
 * 1/(1+exp(X)) -> sigmoid(-X)
 * (X>0)*X -> max(X,0)
 * X*(X>0) -> max(X,0)
 *
 * At most one rewrite is applied per call. After rewiring the parent, the replaced
 * operator and the now-unused intermediate (e.g., (1-X) or (X>0)) are passed to
 * cleanupUnreferenced; the surviving input X is never passed, because the new
 * operator still consumes it.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the parent's inputs (as passed to replaceChildReference)
 * @return the fused replacement operator if a rewrite was applied, otherwise hi
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;
        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (1-X)*X -> sprop(X): left is the (1-X) intermediate, right is X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            // X*(1-X) -> sprop(X): right is the (1-X) intermediate, left is X
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // clean up the (1-X) intermediate, not X (still used by sprop)
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }
        // sigmoid operator
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: 1/(1 + exp(0-X)) -> sigmoid(X)
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0)
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                    } else {
                        // Pattern 2: 1/(1 + exp(X)) -> sigmoid(-X), e.g., where -(-X) has been
                        // removed by the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }
        // select positive (selp) operator (note: same initial pattern as sprop)
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // replace X*(X>0) with max(X,0) due to lower memory requirements and simpler sparsity propagation
            // (X>0)*X -> max(X,0): left is the (X>0) intermediate, right is X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            // X*(X>0) -> max(X,0): right is the (X>0) intermediate, left is X
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // clean up the (X>0) intermediate, not X (still used by max)
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 38 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

Used in class RewriteAlgebraicSimplificationStatic, method simplifyMultiBinaryToBinaryOperation.

/**
 * Rewrites 1-(X*Y) into the fused binary operator X 1-* Y (MINUS1_MULT), which
 * avoids materializing the X*Y intermediate. The rewrite modifies hi in place and
 * is applied only if the product has a single consumer, so the intermediate is
 * actually eliminated.
 *
 * @param hi high-level operator
 * @return hi (possibly modified in place)
 */
private static Hop simplifyMultiBinaryToBinaryOperation(Hop hi) {
    // pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
    if (HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getDataType() == DataType.MATRIX
        && hi.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) hi.getInput().get(0)) == 1
        && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
        // single consumer
        && hi.getInput().get(1).getParent().size() == 1) {
        BinaryOp bop = (BinaryOp) hi;
        Hop mult = hi.getInput().get(1);
        Hop left = mult.getInput().get(0);
        Hop right = mult.getInput().get(1);
        // set new binaryop type and rewire inputs
        bop.setOp(OpOp2.MINUS1_MULT);
        HopRewriteUtils.removeAllChildReferences(hi);
        HopRewriteUtils.addChildReference(bop, left);
        HopRewriteUtils.addChildReference(bop, right);
        // the product lost its only consumer above; release it so it no longer
        // remains registered as a parent of X and Y
        HopRewriteUtils.cleanupUnreferenced(mult);
        LOG.debug("Applied simplifyMultiBinaryToBinaryOperation.");
    }
    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 39 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

Used in class RewriteAlgebraicSimplificationStatic, method fuseOrderOperationChain.

/**
 * Merges a chain of nested order (sort) operations into a single multi-column
 * order, e.g., order(order(X,by=2),by=1) -> order(X,by=[1,2]).
 *
 * The chain is extended only through inner sorts that use a literal 'by', the
 * same 'decreasing' flag as the outermost sort, index.return=FALSE, and that have
 * a single consumer. The outermost 'by' comes first in the merged list. The
 * rewrite is skipped in Hadoop execution mode.
 *
 * @param hi high-level operator
 * @return the merged order operator if at least two sorts were fused, otherwise hi
 */
private static Hop fuseOrderOperationChain(Hop hi) {
    // order(order(X,2),1) -> order(X, by=[1,2])
    if (HopRewriteUtils.isReorg(hi, ReOrgOp.SORT)
        // scalar by
        && hi.getInput().get(1) instanceof LiteralOp
        // scalar desc
        && hi.getInput().get(2) instanceof LiteralOp
        // not ixret
        && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false)
        && !OptimizerUtils.isHadoopExecutionMode()) {
        LiteralOp by = (LiteralOp) hi.getInput().get(1);
        boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp) hi.getInput().get(2));
        // find chain of order operations with same desc/ixret configuration and single consumers
        ArrayList<LiteralOp> byList = new ArrayList<>();
        byList.add(by);
        Hop input = hi.getInput().get(0);
        while (HopRewriteUtils.isReorg(input, ReOrgOp.SORT)
            // scalar by
            && input.getInput().get(1) instanceof LiteralOp
            // same decreasing flag as the outermost sort
            && HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc)
            // not ixret: checked on the inner sort itself; an inner index-returning
            // sort produces indexes, not X, so merging it would change the result
            && HopRewriteUtils.isLiteralOfValue(input.getInput().get(3), false)
            && input.getParent().size() == 1) {
            byList.add((LiteralOp) input.getInput().get(1));
            input = input.getInput().get(0);
        }
        // merge order chain if at least two instances
        if (byList.size() >= 2) {
            // create new order operations
            ArrayList<Hop> inputs = new ArrayList<>();
            inputs.add(input);
            inputs.add(HopRewriteUtils.createDataGenOpByVal(byList, 1, byList.size()));
            inputs.add(new LiteralOp(desc));
            inputs.add(new LiteralOp(false));
            Hop hnew = HopRewriteUtils.createReorg(inputs, ReOrgOp.SORT);
            // cleanup references recursively
            Hop current = hi;
            while (current != input) {
                Hop tmp = current.getInput().get(0);
                HopRewriteUtils.removeAllChildReferences(current);
                current = tmp;
            }
            // rewire all parents (avoid anomalies with replicated datagen)
            List<Hop> parents = new ArrayList<>(hi.getParent());
            for (Hop p : parents) {
                HopRewriteUtils.replaceChildReference(p, hi, hnew);
            }
            hi = hnew;
            LOG.debug("Applied fuseOrderOperationChain (line " + hi.getBeginLine() + ").");
        }
    }
    return hi;
}
Also used : ArrayList(java.util.ArrayList) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 40 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

Used in class DMLTranslator, method processIndexingExpression.

/**
 * Constructs the IndexingOp hop for a right-indexing expression such as X[rl:ru, cl:cu].
 * A missing lower bound defaults to the literal 1. A missing upper bound defaults to the
 * original dimension when it is not -1, and otherwise to an nrow(X)/ncol(X) hop on the
 * indexed variable.
 *
 * @param source indexed identifier (the indexed variable and its bounds)
 * @param target output identifier; if null, one is created via createTarget(source)
 * @param hops map from variable name to hop, used to look up the indexed variable
 * @return the new indexing hop
 */
private Hop processIndexingExpression(IndexedIdentifier source, DataIdentifier target, HashMap<String, Hop> hops) {
    // process Hops for indexes (for source)
    Hop rowLowerHops, rowUpperHops, colLowerHops, colUpperHops;
    if (source.getRowLowerBound() != null) {
        rowLowerHops = processExpression(source.getRowLowerBound(), null, hops);
    } else {
        rowLowerHops = new LiteralOp(1);
    }
    if (source.getRowUpperBound() != null) {
        rowUpperHops = processExpression(source.getRowUpperBound(), null, hops);
    } else if (source.getOrigDim1() != -1) {
        rowUpperHops = new LiteralOp(source.getOrigDim1());
    } else {
        rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(source.getName()));
        rowUpperHops.setParseInfo(source);
    }
    if (source.getColLowerBound() != null) {
        colLowerHops = processExpression(source.getColLowerBound(), null, hops);
    } else {
        colLowerHops = new LiteralOp(1);
    }
    if (source.getColUpperBound() != null) {
        colUpperHops = processExpression(source.getColUpperBound(), null, hops);
    } else if (source.getOrigDim2() != -1) {
        colUpperHops = new LiteralOp(source.getOrigDim2());
    } else {
        colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(source.getName()));
        // attach parse info like the nrow fallback above, so both carry source positions
        colUpperHops.setParseInfo(source);
    }
    if (target == null) {
        target = createTarget(source);
    }
    // unknown nnz after range indexing (applies to indexing op but also
    // data dependent operations)
    target.setNnz(-1);
    Hop indexOp = new IndexingOp(target.getName(), target.getDataType(), target.getValueType(), hops.get(source.getName()), rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper());
    indexOp.setParseInfo(target);
    setIdentifierParams(indexOp, target);
    return indexOp;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Aggregations

LiteralOp (org.apache.sysml.hops.LiteralOp)102 Hop (org.apache.sysml.hops.Hop)88 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)30 BinaryOp (org.apache.sysml.hops.BinaryOp)27 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)25 UnaryOp (org.apache.sysml.hops.UnaryOp)24 DataOp (org.apache.sysml.hops.DataOp)23 IndexingOp (org.apache.sysml.hops.IndexingOp)17 ArrayList (java.util.ArrayList)13 DataGenOp (org.apache.sysml.hops.DataGenOp)13 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)12 HashMap (java.util.HashMap)11 DataIdentifier (org.apache.sysml.parser.DataIdentifier)10 ReorgOp (org.apache.sysml.hops.ReorgOp)9 MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject)9 HopsException (org.apache.sysml.hops.HopsException)8 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)8 StatementBlock (org.apache.sysml.parser.StatementBlock)8 TernaryOp (org.apache.sysml.hops.TernaryOp)7 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)6