Search in sources :

Example 21 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryVectorizeOperation.

/**
 * Replaces a constant-valued RAND data generator that feeds one side of a binary
 * matrix operation with its scalar RAND_MIN input, turning a matrix-matrix
 * operation into a matrix-scalar one.
 *
 * Only binary matrix ops that support matrix-scalar execution qualify, outer
 * operations (column vector op row vector) are skipped, and at most one side
 * is rewritten per call.
 *
 * @param hi high-level operator
 * @return hi (its inputs may have been rewired)
 */
private static Hop removeUnnecessaryVectorizeOperation(Hop hi) {
    if (!(hi instanceof BinaryOp) || hi.getDataType() != DataType.MATRIX
        || !((BinaryOp) hi).supportsMatrixScalarOperations()) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    Hop in0 = bop.getInput().get(0);
    Hop in1 = bop.getInput().get(1);

    // column vector (op) row vector is an outer operation -> leave untouched
    boolean isOuter = in0.getDim1() > 1 && in0.getDim2() == 1
        && in1.getDim1() == 1 && in1.getDim2() > 1;
    if (isOuter) {
        return hi;
    }

    if (in0.getDataType() == DataType.MATRIX && in1 instanceof DataGenOp) {
        // right input is a vectorized scalar
        DataGenOp gen = (DataGenOp) in1;
        if (gen.getOp() == DataGenMethod.RAND && gen.hasConstantValue()) {
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 1);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation1");
        }
    } else if (in1.getDataType() == DataType.MATRIX && in0 instanceof DataGenOp) {
        // left input is a vectorized scalar; only rewrite if it does not define
        // a larger output shape than the right input
        DataGenOp gen = (DataGenOp) in0;
        boolean colsOk = in0.getDim2() == 1 || in1.getDim2() > 1;
        boolean rowsOk = in0.getDim1() == 1 || in1.getDim1() > 1;
        if (gen.getOp() == DataGenMethod.RAND && gen.hasConstantValue() && colsOk && rowsOk) {
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 0);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation2");
        }
    }
    // Note: the rewrite is applied to at most one side to keep the output
    // semantically equivalent. A future extension could remove both vectors,
    // compute the binary op on scalars, and feed the result into a datagenop
    // of the original dimensions.
    return hi;
}
Also used : DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 22 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationStatic, method fuseBinarySubDAGToUnaryOperation.

/**
 * Handles simplification of small binary sub-DAGs into a single, cheaper operator.
 *
 * Supported rewrites (at most one is applied per call):
 *   X*(1-X)        -> sprop(X)
 *   (1-X)*X        -> sprop(X)
 *   1/(1+exp(0-X)) -> sigmoid(X)
 *   1/(1+exp(X))   -> sigmoid(-X)   (only when X is not itself a binary minus)
 *   (X>0)*X        -> max(X,0)
 *   X*(X>0)        -> max(X,0)
 *
 * The new operator replaces hi as input pos of parent. The replaced hops are
 * then passed to HopRewriteUtils.cleanupUnreferenced.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the replacement operator if a rewrite was applied, otherwise hi
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;
        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (1-X)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // left is the (1-X) intermediate
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            // X*(1-X)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // here the (1-X) intermediate is the right input; left (X) is
                    // still consumed by the new sprop, so it must not be the cleanup target
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }
        // sigmoid operator
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    // (0-X) intermediate that becomes unused by pattern 1 (null otherwise)
                    Hop negatedIn = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: 1/(1 + exp(0-X)) -> sigmoid(X)
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0) {
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                            negatedIn = bop3;
                        }
                    } else {
                        // Pattern 2: 1/(1 + exp(X)), e.g., where -(-X) has been removed by
                        // the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                        if (negatedIn != null) {
                            // (0-X) was consumed by uop, so clean it up only after uop
                            HopRewriteUtils.cleanupUnreferenced(negatedIn);
                        }
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }
        // select positive operator, expressed as max(X,0) (same initial pattern as sprop)
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // replace X*(X>0) with max(X,0): avoids the (X>0) intermediate
            // (lower memory requirements) and simplifies sparsity propagation
            // (X>0)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // left is the (X>0) intermediate
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            // X*(X>0)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // here the (X>0) intermediate is the right input; left (X) is
                    // still consumed by the new max, so it must not be the cleanup target
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 23 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Class RewriteAlgebraicSimplificationStatic, method simplifyMultiBinaryToBinaryOperation.

/**
 * Rewrites 1-(X*Y) into the single binary operator X 1-* Y (OpOp2.MINUS1_MULT),
 * avoiding the X*Y intermediate. Applies only if hi is a MINUS of matrix data type
 * whose left input is the literal 1 and whose right input is a MULT consumed
 * solely by hi. The rewrite modifies hi in place.
 *
 * @param hi high-level operator
 * @return hi (possibly with changed operator type and inputs)
 */
private static Hop simplifyMultiBinaryToBinaryOperation(Hop hi) {
    // pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
    if (HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getDataType() == DataType.MATRIX
        && hi.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) hi.getInput().get(0)) == 1
        && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
        // single consumer
        && hi.getInput().get(1).getParent().size() == 1) {
        BinaryOp bop = (BinaryOp) hi;
        Hop mult = hi.getInput().get(1);
        Hop left = mult.getInput().get(0);
        Hop right = mult.getInput().get(1);
        // set new binaryop type and rewire inputs
        bop.setOp(OpOp2.MINUS1_MULT);
        HopRewriteUtils.removeAllChildReferences(hi);
        HopRewriteUtils.addChildReference(bop, left);
        HopRewriteUtils.addChildReference(bop, right);
        // the X*Y intermediate has lost its only consumer; clean it up like the
        // replaced intermediates in the other rewrites so X and Y do not keep it as parent
        HopRewriteUtils.cleanupUnreferenced(mult);
        LOG.debug("Applied simplifyMultiBinaryToBinaryOperation.");
    }
    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 24 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Class DMLTranslator, method processBooleanExpression.

/**
 * Translates a logical boolean expression into a hop: a NOT unary op when the
 * expression has no right operand, otherwise an AND/OR binary op.
 * Scalar targets are typed BOOLEAN; expressions with a constant operand are rejected.
 *
 * @param source boolean expression from the parse tree
 * @param target output identifier, or null to create one from source
 * @param hops map of high-level operators
 * @return the constructed NOT, AND or OR hop
 * @throws RuntimeException if an operand is a constant or the operator is neither AND nor OR
 */
private Hop processBooleanExpression(BooleanExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    // a boolean NOT has only a left operand
    boolean unaryNot = source.getRight() == null;
    if (source.getLeft().getOutput() instanceof ConstIdentifier
        || (!unaryNot && source.getRight().getOutput() instanceof ConstIdentifier)) {
        String msg = source.printErrorLocation() + "Boolean expression with constant unsupported";
        LOG.error(msg);
        throw new RuntimeException(msg);
    }

    Hop lhs = processExpression(source.getLeft(), null, hops);
    Hop rhs = unaryNot ? null : processExpression(source.getRight(), null, hops);

    // create target in place if not provided
    // (type should not be determined by target (e.g., string for print))
    if (target == null) {
        target = createTarget(source);
    }
    if (target.getDataType().isScalar()) {
        target.setValueType(ValueType.BOOLEAN);
    }

    Hop result;
    if (unaryNot) {
        result = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, lhs);
    } else {
        OpOp2 op;
        if (source.getOpCode() == Expression.BooleanOp.LOGICALAND) {
            op = OpOp2.AND;
        } else if (source.getOpCode() == Expression.BooleanOp.LOGICALOR) {
            op = OpOp2.OR;
        } else {
            String msg = source.printErrorLocation() + "Unknown boolean operation " + source.getOpCode();
            LOG.error(msg);
            throw new RuntimeException(msg);
        }
        result = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, lhs, rhs);
    }
    result.setParseInfo(source);
    return result;
}
Also used : DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 25 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Class DMLTranslator, method processBinaryExpression.

/**
 * Construct Hops from parse tree: process a binary expression in an
 * assignment statement.
 *
 * Matrix multiplication becomes an AggBinaryOp (MULT with SUM aggregation);
 * all other supported operators become a cell-wise BinaryOp.
 *
 * @param source binary expression
 * @param target data identifier, or null to create one from source
 * @param hops map of high-level operators
 * @return high-level operator
 * @throws ParseException if an operand cannot be translated or the operator is unsupported
 */
private Hop processBinaryExpression(BinaryExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    Hop left = processExpression(source.getLeft(), null, hops);
    Hop right = processExpression(source.getRight(), null, hops);
    // Retry only an operand that failed to translate; re-processing an operand that
    // already succeeded would build a second, duplicate sub-DAG for it.
    if (left == null) {
        left = processExpression(source.getLeft(), null, hops);
    }
    if (right == null) {
        right = processExpression(source.getRight(), null, hops);
    }
    if (left == null || right == null) {
        // fail fast instead of constructing an operator with a null input
        throw new ParseException("Unable to construct hops for operands of binary expression: " + source.getOpCode());
    }
    // create target in place if not provided
    // (type should not be determined by target (e.g., string for print))
    if (target == null) {
        target = createTarget(source);
    }
    target.setValueType(source.getOutput().getValueType());

    Hop currBop;
    if (source.getOpCode() == Expression.BinaryOp.MATMULT) {
        // matrix multiplication: sum-product aggregate rather than a cell-wise op
        currBop = new AggBinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, left, right);
    } else {
        OpOp2 op;
        if (source.getOpCode() == Expression.BinaryOp.PLUS) {
            op = OpOp2.PLUS;
        } else if (source.getOpCode() == Expression.BinaryOp.MINUS) {
            op = OpOp2.MINUS;
        } else if (source.getOpCode() == Expression.BinaryOp.MULT) {
            op = OpOp2.MULT;
        } else if (source.getOpCode() == Expression.BinaryOp.DIV) {
            op = OpOp2.DIV;
        } else if (source.getOpCode() == Expression.BinaryOp.MODULUS) {
            op = OpOp2.MODULUS;
        } else if (source.getOpCode() == Expression.BinaryOp.INTDIV) {
            op = OpOp2.INTDIV;
        } else if (source.getOpCode() == Expression.BinaryOp.POW) {
            op = OpOp2.POW;
        } else {
            throw new ParseException("Unsupported parsing of binary expression: " + source.getOpCode());
        }
        currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right);
    }
    setIdentifierParams(currBop, source.getOutput());
    currBop.setParseInfo(source);
    return currBop;
}
Also used : AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

BinaryOp (org.apache.sysml.hops.BinaryOp)58 Hop (org.apache.sysml.hops.Hop)57 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)50 LiteralOp (org.apache.sysml.hops.LiteralOp)27 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)23 UnaryOp (org.apache.sysml.hops.UnaryOp)18 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)9 IndexingOp (org.apache.sysml.hops.IndexingOp)6 ReorgOp (org.apache.sysml.hops.ReorgOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 DataOp (org.apache.sysml.hops.DataOp)5 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)5 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HashMap (java.util.HashMap)3 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3