Search in sources:

Example 51 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

From the class RewriteAlgebraicSimplificationStatic, method simplifyTraceMatrixMult.

/**
 * Rewrites trace(X %*% Y) into sum(X * t(Y)), which avoids materializing the
 * full matrix product when only the sum of its diagonal is needed.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate trace aggregate)
 * @param pos input position of hi in parent (as passed to replaceChildReference)
 * @return the new sum aggregate if the rewrite was applied, otherwise hi
 */
private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) {
    // only trace(...) aggregates qualify
    boolean isTrace = hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE;
    if (!isTrace)
        return hi;

    // ... whose single input is a matrix multiplication X %*% Y
    Hop matMult = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(matMult))
        return hi;

    Hop lhs = matMult.getInput().get(0);
    Hop rhs = matMult.getInput().get(1);

    // build sum(X * t(Y)); per the original note, createTranspose refreshes sizes itself
    ReorgOp rhsTransposed = HopRewriteUtils.createTranspose(rhs);
    BinaryOp cellProduct = HopRewriteUtils.createBinary(lhs, rhsTransposed, OpOp2.MULT);
    AggUnaryOp total = HopRewriteUtils.createSum(cellProduct);

    // hang the new sub-DAG under parent and drop the now-unused trace / matmult nodes
    HopRewriteUtils.replaceChildReference(parent, hi, total, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, matMult);
    LOG.debug("Applied simplifyTraceMatrixMult");
    return total;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 52 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

From the class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndBinaryOperation.

/**
 * Fuses a binary operation between a rand data generator and a scalar into the data
 * generator itself, which makes the separate binary operation unnecessary.
 *
 * Literal scalar, rand on the left (uniform pdf, literal min/max):
 *   rand*7 -> rand(min*7,max*7); rand/7 -> rand(min*(1/7),max*(1/7));
 *   rand+7 -> rand(min+7,max+7); rand-7 -> rand(min+(-7),max+(-7))
 *   rand/0 is NOT fused: it yields infinite/NaN cells that no rescaled range reproduces.
 * Literal scalar, rand on the right (uniform pdf, literal min/max):
 *   7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7)
 * Scalar s (not necessarily a literal), rand on the left:
 *   rand(min=0,max=0)+s -> rand(min=s,max=s)
 *   rand(min=0,max=1,pdf=uniform)*s -> rand(min=0,max=s)
 *   rand(min=1,max=1)*s -> rand(min=s,max=s)
 *
 * Every rewrite requires the rand operator to have exactly one parent (the binary
 * operation), so it can be copied or modified without affecting other consumers.
 *
 * @param hi high-order operation
 * @return the fused data generator if a rewrite was applied, otherwise {@code hi}
 */
private static Hop fuseDatagenAndBinaryOperation(Hop hi) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        // left input rand and hence output matrix double, right scalar literal
        if (HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right instanceof LiteralOp && left.getParent().size() == 1) {
            DataGenOp inputGen = (DataGenOp) left;
            Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
            Hop min = inputGen.getInput(DataExpression.RAND_MIN);
            Hop max = inputGen.getInput(DataExpression.RAND_MAX);
            double sval = ((LiteralOp) right).getDoubleValue();
            boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
            // rand/0 yields infinite/NaN cells (IEEE division by zero); a scale factor of
            // 1/0 cannot express that, so such a division must be kept as is
            boolean divByZero = bop.getOp() == OpOp2.DIV && sval == 0;
            if (HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV) && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform && !divByZero) {
                // create fused data gen operator (fuse via scale and shift)
                DataGenOp gen = null;
                switch (bop.getOp()) {
                    case MULT:
                        gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0);
                        break;
                    case PLUS:
                    case MINUS:
                        gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval * ((bop.getOp() == OpOp2.MINUS) ? -1 : 1));
                        break;
                    case DIV:
                        gen = HopRewriteUtils.copyDataGenOp(inputGen, 1 / sval, 0);
                        break;
                    default:
                        // unreachable: the operator is restricted by the isBinary check above
                        break;
                }
                hi = rewireParentsToDatagen(bop, gen);
                LOG.debug("Applied fuseDatagenAndBinaryOperation1 " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
            }
        } else if (HopRewriteUtils.isDataGenOp(right, DataGenMethod.RAND) && left instanceof LiteralOp && right.getParent().size() == 1) {
            // right input rand and hence output matrix double, left scalar literal
            DataGenOp inputGen = (DataGenOp) right;
            Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
            Hop min = inputGen.getInput(DataExpression.RAND_MIN);
            Hop max = inputGen.getInput(DataExpression.RAND_MAX);
            double sval = ((LiteralOp) left).getDoubleValue();
            boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
            // only commutative ops qualify: s-rand and s/rand are not a scale/shift of rand
            if (HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS) && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform) {
                // create fused data gen operator
                DataGenOp gen = (bop.getOp() == OpOp2.MULT)
                    ? HopRewriteUtils.copyDataGenOp(inputGen, sval, 0)
                    : HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); // OpOp2.PLUS
                hi = rewireParentsToDatagen(bop, gen);
                LOG.debug("Applied fuseDatagenAndBinaryOperation2 " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
            }
        } else if (HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right.getDataType().isScalar() && left.getParent().size() == 1) {
            // left input rand and hence output matrix double, right scalar variable:
            // the scalar is wired directly into the generator's min/max inputs
            DataGenOp gen = (DataGenOp) left;
            Hop min = gen.getInput(DataExpression.RAND_MIN);
            Hop max = gen.getInput(DataExpression.RAND_MAX);
            Hop pdf = gen.getInput(DataExpression.RAND_PDF);
            boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
            if (HopRewriteUtils.isBinary(bop, OpOp2.PLUS) && HopRewriteUtils.isLiteralOfValue(min, 0) && HopRewriteUtils.isLiteralOfValue(max, 0)) {
                // rand(min=0,max=0)+s -> rand(min=s,max=s)
                gen.setInput(DataExpression.RAND_MIN, right);
                gen.setInput(DataExpression.RAND_MAX, right);
                hi = rewireParentsToDatagen(bop, gen);
                LOG.debug("Applied fuseDatagenAndBinaryOperation3a " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
            } else if (HopRewriteUtils.isBinary(bop, OpOp2.MULT) && ((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform) || HopRewriteUtils.isLiteralOfValue(min, 1)) && HopRewriteUtils.isLiteralOfValue(max, 1)) {
                // rand(min=0,max=1,uniform)*s -> rand(min=0,max=s);
                // rand(min=1,max=1)*s -> rand(min=s,max=s)
                if (HopRewriteUtils.isLiteralOfValue(min, 1))
                    gen.setInput(DataExpression.RAND_MIN, right);
                gen.setInput(DataExpression.RAND_MAX, right);
                hi = rewireParentsToDatagen(bop, gen);
                LOG.debug("Applied fuseDatagenAndBinaryOperation3b " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
            }
        }
    }
    return hi;
}

/**
 * Redirects every parent of the given binary operation to consume the given data
 * generator instead (original note: avoids anomalies with replicated datagen).
 *
 * @param bop binary operation that is being replaced
 * @param gen data generator that takes its place
 * @return {@code gen}
 */
private static Hop rewireParentsToDatagen(BinaryOp bop, DataGenOp gen) {
    // iterate over a snapshot; replaceChildReference presumably edits bop's parent list
    List<Hop> parents = new ArrayList<>(bop.getParent());
    for (Hop p : parents)
        HopRewriteUtils.replaceChildReference(p, bop, gen);
    return gen;
}
Also used : DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) List(java.util.List) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 53 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

From the class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryBinaryOperation.

/**
 * Removes or simplifies binary operations over a matrix X and a neutral or negating
 * literal operand (literals are compared by their double value):
 *
 * X/1, X*1, 1*X, X-0 -> X
 * -1*X, X*-1 -> 0-X  (operator hi is changed to MINUS in place)
 *
 * The original code labels the negation rewrite as being done "due to mechanical
 * reasons"; presumably a unary minus surfaces as -1*X and 0-X is the preferred form
 * downstream (NOTE(review): confirm).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos input position of hi in parent (as passed to replaceChildReference)
 * @return the operand X if the operation was removed, otherwise hi (possibly modified)
 */
private static Hop removeUnnecessaryBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);
    boolean lhsIsMatrix = lhs.getDataType() == DataType.MATRIX;

    // X/1, X*1 -> X
    if (lhsIsMatrix && rhs instanceof LiteralOp && ((LiteralOp) rhs).getDoubleValue() == 1.0) {
        if (bop.getOp() != OpOp2.DIV && bop.getOp() != OpOp2.MULT)
            return hi;
        HopRewriteUtils.replaceChildReference(parent, bop, lhs, pos);
        LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line " + bop.getBeginLine() + ")");
        return lhs;
    }

    // X-0 -> X
    if (lhsIsMatrix && rhs instanceof LiteralOp && ((LiteralOp) rhs).getDoubleValue() == 0.0) {
        if (bop.getOp() != OpOp2.MINUS)
            return hi;
        HopRewriteUtils.replaceChildReference(parent, bop, lhs, pos);
        LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line " + bop.getBeginLine() + ")");
        return lhs;
    }

    // 1*X -> X
    if (rhs.getDataType() == DataType.MATRIX && lhs instanceof LiteralOp && ((LiteralOp) lhs).getDoubleValue() == 1.0) {
        if (bop.getOp() != OpOp2.MULT)
            return hi;
        HopRewriteUtils.replaceChildReference(parent, bop, rhs, pos);
        LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line " + bop.getBeginLine() + ")");
        return rhs;
    }

    // -1*X -> 0-X: turn the operator into MINUS and swap the -1 literal for 0
    if (rhs.getDataType() == DataType.MATRIX && lhs instanceof LiteralOp && ((LiteralOp) lhs).getDoubleValue() == -1.0) {
        if (bop.getOp() == OpOp2.MULT) {
            bop.setOp(OpOp2.MINUS);
            HopRewriteUtils.replaceChildReference(bop, lhs, new LiteralOp(0), 0);
            LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line " + bop.getBeginLine() + ")");
        }
        return hi;
    }

    // X*-1 -> 0-X: drop the trailing -1 literal and prepend a 0 literal
    if (lhsIsMatrix && rhs instanceof LiteralOp && ((LiteralOp) rhs).getDoubleValue() == -1.0) {
        if (bop.getOp() == OpOp2.MULT) {
            bop.setOp(OpOp2.MINUS);
            HopRewriteUtils.removeChildReferenceByPos(bop, rhs, 1);
            HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0);
            LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line " + bop.getBeginLine() + ")");
        }
        return hi;
    }

    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 54 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

From the class RewriteAlgebraicSimplificationStatic, method canonicalizeMatrixMultScalarAdd.

/**
 * Brings scalar additions/subtractions around a matrix multiplication into the single
 * canonical form U%*%V + s, so that later rewrites (e.g., wdivmm or wcemm with epsilon)
 * only need to match one pattern:
 *
 * eps + U%*%V -> U%*%V + eps      (inputs of hi are swapped in place)
 * U%*%V - eps -> U%*%V + (-eps)   (operator becomes PLUS, scalar input is negated)
 *
 * @param hi high-level operator; modified in place when a pattern matches
 * @return hi
 */
private static Hop canonicalizeMatrixMultScalarAdd(Hop hi) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);
    OpOp2 op = bop.getOp();

    if (lhs.getDataType().isScalar() && rhs instanceof AggBinaryOp && op == OpOp2.PLUS) {
        // eps + U%*%V: re-attach the inputs in swapped order
        HopRewriteUtils.removeAllChildReferences(bop);
        HopRewriteUtils.addChildReference(bop, rhs, 0);
        HopRewriteUtils.addChildReference(bop, lhs, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line " + bop.getBeginLine() + ").");
    } else if (rhs.getDataType().isScalar() && lhs instanceof AggBinaryOp && op == OpOp2.MINUS) {
        // U%*%V - eps: switch to PLUS and feed the negated scalar as second input
        bop.setOp(OpOp2.PLUS);
        Hop negated = HopRewriteUtils.createBinaryMinus(rhs);
        HopRewriteUtils.replaceChildReference(bop, rhs, negated, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line " + bop.getBeginLine() + ").");
    }
    return hi;
}
Also used : AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 55 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

From the class RewriteAlgebraicSimplificationStatic, method simplifyBinaryMatrixScalarOperation.

/**
 * Pushes a cast-to-scalar through a binary operation, so that the binary operation is
 * evaluated over scalars instead of matrices:
 *
 * as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
 * as.scalar(X*s) -> as.scalar(X) * s
 * as.scalar(s*X) -> s * as.scalar(X)
 *
 * Only binary operators contained in LOOKUP_VALID_SCALAR_BINARY are considered.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate cast-as-scalar)
 * @param pos input position of hi in parent (as passed to replaceChildReference)
 * @return the new scalar binary operator if the rewrite was applied, otherwise hi
 */
private static Hop simplifyBinaryMatrixScalarOperation(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) && hi.getInput().get(0) instanceof BinaryOp && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) {
        BinaryOp bin = (BinaryOp) hi.getInput().get(0);
        Hop in1 = bin.getInput().get(0);
        Hop in2 = bin.getInput().get(1);
        BinaryOp bout = null;
        if (in1.getDataType() == DataType.MATRIX && in2.getDataType() == DataType.MATRIX) {
            // as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
            UnaryOp cast1 = HopRewriteUtils.createUnary(in1, OpOp1.CAST_AS_SCALAR);
            UnaryOp cast2 = HopRewriteUtils.createUnary(in2, OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp());
        } else if (in1.getDataType() == DataType.MATRIX) {
            // as.scalar(X*s) -> as.scalar(X) * s
            UnaryOp cast = HopRewriteUtils.createUnary(in1, OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(cast, in2, bin.getOp());
        } else if (in2.getDataType() == DataType.MATRIX) {
            // as.scalar(s*X) -> s * as.scalar(X)
            UnaryOp cast = HopRewriteUtils.createUnary(in2, OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(in1, cast, bin.getOp());
        }
        if (bout != null) {
            HopRewriteUtils.replaceChildReference(parent, hi, bout, pos);
            // return the replacement (like the sibling rewrites do) so that subsequent
            // rewrites operate on the node now linked under parent, not the detached cast
            hi = bout;
            LOG.debug("Applied simplifyBinaryMatrixScalarOperation.");
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

BinaryOp (org.apache.sysml.hops.BinaryOp)58 Hop (org.apache.sysml.hops.Hop)57 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)50 LiteralOp (org.apache.sysml.hops.LiteralOp)27 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)23 UnaryOp (org.apache.sysml.hops.UnaryOp)18 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)9 IndexingOp (org.apache.sysml.hops.IndexingOp)6 ReorgOp (org.apache.sysml.hops.ReorgOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 DataOp (org.apache.sysml.hops.DataOp)5 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)5 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HashMap (java.util.HashMap)3 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3