Search in sources:

Example 41 with BinaryOp

Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

Method simplifyWeightedCrossEntropy of class RewriteAlgebraicSimplificationDynamic.

/**
 * Rewrites a full-sum cross-entropy expression into a fused WCEMM quaternary operator.
 * <p>
 * Both recognized shapes are rooted at a row-and-column {@code sum} over a cell-wise
 * multiply whose output has more than one column:
 * <ol>
 *   <li>{@code sum(X * log(U %*% t(V)))}, emitted with an epsilon literal {@code 0.0} and
 *       type code 0;</li>
 *   <li>{@code sum(X * log(U %*% t(V) + eps))} with a scalar literal {@code eps}, emitted
 *       with type code 1 (BASIC_EPS).</li>
 * </ol>
 * If {@code V} is already a transpose, its input is used directly; otherwise a transpose
 * of {@code V} is created. If neither shape matches, {@code hi} is returned untouched.
 *
 * @param parent parent of {@code hi}; receives the replacement at {@code pos}
 * @param hi candidate root hop
 * @param pos index of {@code hi} among the inputs of {@code parent}
 * @return the new quaternary hop if a pattern was applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
    // The root must be sum(...) over all cells.
    if (!(hi instanceof AggUnaryOp))
        return hi;
    AggUnaryOp sum = (AggUnaryOp) hi;
    if (sum.getDirection() != Direction.RowCol || sum.getOp() != AggOp.SUM)
        return hi;

    // Its input must be a binary op with more than one column (skips vector-vector mult).
    Hop summand = hi.getInput().get(0);
    if (!(summand instanceof BinaryOp) || summand.getDim2() <= 1)
        return hi;

    BinaryOp bop = (BinaryOp) summand;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    // Common prefix of both patterns: matrix X times log(...) of equal size (no mv broadcast).
    boolean weightedLog = bop.getOp() == OpOp2.MULT
        && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isEqualSize(left, right)
        && HopRewriteUtils.isUnary(right, OpOp1.LOG);
    if (!weightedLog)
        return hi;

    Hop logArg = right.getInput().get(0);
    Hop hnew;
    if (logArg instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
        // Pattern 1) sum( X * log(U %*% t(V)))
        Hop X = left;
        Hop U = logArg.getInput().get(0);
        Hop V = logArg.getInput().get(1);
        V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM,
            X, U, V, new LiteralOp(0.0), 0, false, false);
        hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
    } else if (HopRewriteUtils.isBinary(logArg, OpOp2.PLUS)
        && logArg.getInput().get(0) instanceof AggBinaryOp
        && logArg.getInput().get(1) instanceof LiteralOp
        && logArg.getInput().get(1).getDataType() == DataType.SCALAR
        && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0).getInput().get(0), true)) {
        // Pattern 2) sum( X * log(U %*% t(V) + eps))
        Hop mm = logArg.getInput().get(0);
        Hop X = left;
        Hop U = mm.getInput().get(0);
        Hop V = mm.getInput().get(1);
        Hop eps = logArg.getInput().get(1);
        V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
        // type code 1 => BASIC_EPS
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM,
            X, U, V, eps, 1, false, false);
        hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
    } else {
        return hi;
    }

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 42 with BinaryOp

Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

Method removeUnnecessaryOuterProduct of class RewriteAlgebraicSimplificationDynamic.

/**
 * Removes outer products with an all-ones operand that only replicate a vector for a
 * cell-wise binary operation.
 * <ul>
 *   <li>1) {@code A op (b %*% ones) -> A op b}, where {@code b} has one column;</li>
 *   <li>2) {@code A op (ones %*% b) -> A op b}, where {@code b} has one row;</li>
 *   <li>3) {@code (a %*% ones) op b -> outer(a, b, op)}, where {@code op} is a valid outer
 *       binary op and {@code b} has one row.</li>
 * </ul>
 * Rewrites 1 and 2 keep {@code hi} and only replace its right input. Rewrite 3 replaces
 * {@code hi} in {@code parent}.
 * <p>
 * Rewrites 1 and 2 additionally require a matrix left input. For {@code s op (b %*% ones)}
 * with a scalar {@code s}, the replicated right input alone defines the output size, so
 * dropping the replication would shrink the result from a matrix to a vector.
 *
 * @param parent parent of {@code hi}; only modified by rewrite 3
 * @param hi candidate binary cell operation
 * @param pos index of {@code hi} among the inputs of {@code parent}
 * @return the new outer binary hop for rewrite 3, otherwise {@code hi}
 */
private static Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) { // binary cell operation
        OpOp2 bop = ((BinaryOp) hi).getOp();
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        // Rewrites 1 and 2 are only size-preserving if left (not the replicated right input)
        // defines the output size, i.e., if left is a matrix rather than a scalar.
        boolean matrixLeft = left.getDataType() == DataType.MATRIX;

        if (matrixLeft
            && HopRewriteUtils.isMatrixMultiply(right) // matrix mult with datagen
            && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
            && right.getInput().get(0).getDim2() == 1) { // column vector for mv binary
            // 1) matrix-vector column replication: (A + b %*% ones) -> (A + b)
            HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1);
            HopRewriteUtils.cleanupUnreferenced(right);
            LOG.debug("Applied removeUnnecessaryOuterProduct1 (line " + right.getBeginLine() + ")");
        } else if (matrixLeft
            && HopRewriteUtils.isMatrixMultiply(right) // matrix mult with datagen
            && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1)
            && right.getInput().get(1).getDim1() == 1) { // row vector for mv binary
            // 2) matrix-vector row replication: (A + ones %*% b) -> (A + b)
            HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1);
            HopRewriteUtils.cleanupUnreferenced(right);
            LOG.debug("Applied removeUnnecessaryOuterProduct2 (line " + right.getBeginLine() + ")");
        } else if (HopRewriteUtils.isValidOuterBinaryOp(bop)
            && HopRewriteUtils.isMatrixMultiply(left)
            && HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
            && (left.getInput().get(0).getDim2() == 1 // outer product
                || left.getInput().get(1).getDim1() == 1)
            && left.getDim1() != 1
            && right.getDim1() == 1) { // outer vector binary
            // 3) vector-vector column replication: (a %*% ones) == b) -> outer(a, b, "==")
            Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
            HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
            HopRewriteUtils.cleanupUnreferenced(hi);
            hi = hnew;
            LOG.debug("Applied removeUnnecessaryOuterProduct3 (line " + right.getBeginLine() + ")");
        }
    }
    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 43 with BinaryOp

Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

Method simplifyEmptyBinaryOperation of class RewriteAlgebraicSimplificationDynamic.

/**
 * Folds cell-wise {@code *}, {@code +} and {@code -} between two matrix inputs when an input
 * is reported empty by {@code HopRewriteUtils.isEmpty}.
 * <ul>
 *   <li>{@code X * Y}: replaced by a zero datagen. Its shape comes from {@code left}, or
 *       from {@code right} only when {@code left} is not empty and {@code right} has more
 *       than one row and more than one column.</li>
 *   <li>{@code X + Y}: a zero datagen of {@code left}'s shape if both are empty. Otherwise
 *       {@code right} if {@code left} is empty and the op is not matrix-vector, or
 *       {@code left} if {@code right} is empty.</li>
 *   <li>{@code X - Y}: if {@code left} is empty and the op is not matrix-vector, {@code hi}
 *       is kept but its left input is swapped for the scalar literal 0. Otherwise
 *       {@code left} is used if {@code right} is empty.</li>
 * </ul>
 *
 * @param parent parent of {@code hi}; receives the replacement at {@code pos}
 * @param hi candidate binary operation
 * @param pos index of {@code hi} among the inputs of {@code parent}
 * @return the replacement hop if a rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptyBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;
    BinaryOp bop = (BinaryOp) hi;
    Hop left = hi.getInput().get(0);
    Hop right = hi.getInput().get(1);
    if (left.getDataType() != DataType.MATRIX || right.getDataType() != DataType.MATRIX)
        return hi;

    // NOTE: right may be a vector broadcast against left, in which case the
    // result has the size of left, not of right.
    boolean notBinaryMV = HopRewriteUtils.isNotMatrixVectorBinaryOperation(bop);
    Hop replacement = null;
    switch (bop.getOp()) {
        case MULT: {
            // X * Y -> matrix(0, nrow(X), ncol(X))
            if (HopRewriteUtils.isEmpty(left)) {
                replacement = HopRewriteUtils.createDataGenOp(left, left, 0);
            } else if (HopRewriteUtils.isEmpty(right)) {
                // only a right input that is certainly not a vector may dictate the shape
                boolean rightFullMatrix = right.getDim1() > 1 && right.getDim2() > 1;
                Hop shape = rightFullMatrix ? right : left;
                replacement = HopRewriteUtils.createDataGenOp(shape, shape, 0);
            }
            break;
        }
        case PLUS: {
            boolean leftEmpty = HopRewriteUtils.isEmpty(left);
            if (leftEmpty && HopRewriteUtils.isEmpty(right)) {
                replacement = HopRewriteUtils.createDataGenOp(left, left, 0); // 0 + 0
            } else if (leftEmpty && notBinaryMV) {
                replacement = right; // 0 + Y
            } else if (HopRewriteUtils.isEmpty(right)) {
                replacement = left; // X + 0
            }
            break;
        }
        case MINUS: {
            if (HopRewriteUtils.isEmpty(left) && notBinaryMV) {
                // 0 - Y: keep the operator, but swap the empty matrix for the literal 0
                HopRewriteUtils.removeChildReference(hi, left);
                HopRewriteUtils.addChildReference(hi, new LiteralOp(0), 0);
                replacement = hi;
            } else if (HopRewriteUtils.isEmpty(right)) {
                replacement = left; // X - 0
            }
            break;
        }
        default:
            replacement = null;
    }

    if (replacement == null)
        return hi;
    // hang the replacement under the parent
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    LOG.debug("Applied simplifyEmptyBinaryOperation");
    return replacement;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 44 with BinaryOp

Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

Method reorderMinusMatrixMult of class RewriteAlgebraicSimplificationDynamic.

/**
 * This rewrite tries to reorder minus operators from inputs of matrix
 * multiply to its output because the output is (except for outer products)
 * usually significantly smaller: {@code (0 - Z) %*% Y -> 0 - (Z %*% Y)} and
 * {@code X %*% (0 - Z) -> 0 - (X %*% Z)}. Furthermore, this rewrite is a
 * precondition for the important hops-lops rewrite of transpose-matrixmult if
 * the transpose is hidden under the minus.
 * <p>
 * The rewrite only fires if the dimensions of the matrix multiply and of {@code Z}
 * are known and {@code HopRewriteUtils.compareSize(mm, Z) < 0}. The left input
 * is tried first, and at most one input is rewritten per call.
 * <p>
 * NOTE: in this rewrite we need to modify the links to all parents because we
 * remove existing links of subdags and hence affect all consumers. Therefore
 * {@code parent} and {@code pos} are not used directly.
 *
 * @param parent the parent high-level operator (all consumers of {@code hi} are relinked)
 * @param hi high-level operator
 * @param pos position of {@code hi} in {@code parent}
 * @return the new minus operator if the rewrite applied, otherwise {@code hi}
 */
private static Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isMatrixMultiply(hi)) { // X%*%Y
        if (isReorderableMinusMatrixMultInput(hi, hi.getInput().get(0)))
            hi = pullMinusAboveMatrixMult(hi, 0);
        else if (isReorderableMinusMatrixMultInput(hi, hi.getInput().get(1)))
            hi = pullMinusAboveMatrixMult(hi, 1);
    }
    return hi;
}

/**
 * Checks if {@code input} of matrix multiply {@code mm} has the form {@code 0 - Z}, with
 * known dimensions of {@code mm} and {@code Z} and {@code compareSize(mm, Z) < 0}.
 */
private static boolean isReorderableMinusMatrixMultInput(Hop mm, Hop input) {
    return HopRewriteUtils.isBinary(input, OpOp2.MINUS) // X=-Z
        && input.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) input.getInput().get(0)) == 0.0
        && mm.dimsKnown()
        && input.getInput().get(1).dimsKnown() // size comparison
        && HopRewriteUtils.compareSize(mm, input.getInput().get(1)) < 0;
}

/**
 * Replaces input {@code ix} of {@code mm} (a {@code 0 - Z} hop) by {@code Z} and hangs a
 * new {@code 0 - mm} operator under every former consumer of {@code mm}.
 *
 * @return the new minus operator, now consumed in place of {@code mm}
 */
private static Hop pullMinusAboveMatrixMult(Hop mm, int ix) {
    Hop minusInput = mm.getInput().get(ix);
    Hop negated = minusInput.getInput().get(1);
    // remove link from matrixmult to minus
    HopRewriteUtils.removeChildReference(mm, minusInput);
    // get old parents (before creating minus over matrix mult); a typed copy avoids an unchecked cast
    ArrayList<Hop> parents = new ArrayList<>(mm.getParent());
    // create new operators
    BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), mm, OpOp2.MINUS);
    // rehang minus under all parents
    for (Hop p : parents) {
        int pix = HopRewriteUtils.getChildReferencePos(p, mm);
        HopRewriteUtils.removeChildReference(p, mm);
        HopRewriteUtils.addChildReference(p, minus, pix);
    }
    // rehang child of minus under matrix mult, at the vacated input position
    HopRewriteUtils.addChildReference(mm, negated, ix);
    // cleanup if only consumer of minus
    HopRewriteUtils.cleanupUnreferenced(minusInput);
    LOG.debug("Applied reorderMinusMatrixMult (line " + minus.getBeginLine() + ").");
    return minus;
}
Also used : Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 45 with BinaryOp

Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

Method simplifyWeightedUnaryMM of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses a weighted unary/scalar operation over a matrix multiply into a WUMM quaternary
 * operator.
 * <p>
 * The shapes are tried in this order, and at most one is applied:
 * <ul>
 *   <li>Pattern 1: {@code W op uop(U %*% t(V))}, with {@code op} in
 *       LOOKUP_VALID_WDIVMM_BINARY and {@code uop} in LOOKUP_VALID_WUMM_UNARY;</li>
 *   <li>Pattern 2.7: {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))}, where
 *       the inner multiply has no other consumers;</li>
 *   <li>Pattern 2: {@code W op sop(U %*% t(V), 2)}, or {@code W op (2 * (U %*% t(V)))}, with
 *       {@code sop} in LOOKUP_VALID_WUMM_BINARY.</li>
 * </ul>
 * If {@code V} is already a transpose, its input is used directly; otherwise a transpose
 * is created.
 *
 * @param parent parent of {@code hi}; receives the replacement at {@code pos}
 * @param hi candidate root hop
 * @param pos index of {@code hi} among the inputs of {@code parent}
 * @return the new quaternary hop if a pattern was applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
    // every supported pattern is rooted at a binary cell operation
    if (!(hi instanceof BinaryOp))
        return hi;
    final OpOp2 rootOp = ((BinaryOp) hi).getOp();

    // Preconditions shared by patterns 1 and 2: a supported root op, equal-size inputs
    // (prevent mv), more than one output column (not applied for vector-vector mult),
    // and a matrix W with more columns than its column block size.
    final boolean weightedCandidate = HopRewriteUtils.isValidOp(rootOp, LOOKUP_VALID_WDIVMM_BINARY)
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1))
        && hi.getDim2() > 1
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock();

    Hop hnew = null;

    // Pattern 1) (W*uop(U%*%t(V)))
    if (weightedCandidate && hi.getInput().get(1) instanceof UnaryOp) {
        UnaryOp uop = (UnaryOp) hi.getInput().get(1);
        if (HopRewriteUtils.isValidOp(uop.getOp(), LOOKUP_VALID_WUMM_UNARY)
            && uop.getInput().get(0) instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop mm = uop.getInput().get(0);
            Hop W = hi.getInput().get(0);
            Hop U = mm.getInput().get(0);
            Hop V = mm.getInput().get(1);
            boolean mult = rootOp == OpOp2.MULT;
            OpOp1 op = uop.getOp();
            V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, op, null);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
        }
    }

    // Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
    if (hnew == null
        && HopRewriteUtils.isValidOp(rootOp, OpOp2.MULT)
        && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
            || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) {
        // the non-literal operand
        final Hop nl = (hi.getInput().get(0) instanceof LiteralOp) ? hi.getInput().get(1) : hi.getInput().get(0);
        if (HopRewriteUtils.isBinary(nl, OpOp2.MULT)
            && nl.getParent().size() == 1 // ensure no foreign parents
            && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) // prevent mv
            && nl.getDim2() > 1 // not applied for vector-vector mult
            && nl.getInput().get(0).getDataType() == DataType.MATRIX
            && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
            && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
            && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE // no mmchain
                || nl.getInput().get(1).getInput().get(1).getDim2() > 1)
            && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0), true)) {
            final Hop W = nl.getInput().get(0);
            final Hop U = nl.getInput().get(1).getInput().get(0);
            Hop V = nl.getInput().get(1).getInput().get(1);
            V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
        }
    }

    // Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (hnew == null
        && weightedCandidate
        && hi.getInput().get(1) instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        BinaryOp sop = (BinaryOp) hi.getInput().get(1);
        Hop left = sop.getInput().get(0);
        Hop right = sop.getInput().get(1);
        Hop abop = null;
        if (right.getDataType() == DataType.SCALAR && right instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) right) == 2 // pow2, mult2
            && left instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(left.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2a) matrix-scalar operations
            abop = left;
        } else if (left.getDataType() == DataType.SCALAR && left instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 2 // mult2
            && sop.getOp() == OpOp2.MULT
            && right instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(right.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2b) scalar-matrix operations
            abop = right;
        }
        if (abop != null) {
            Hop W = hi.getInput().get(0);
            Hop U = abop.getInput().get(0);
            Hop V = abop.getInput().get(1);
            boolean mult = rootOp == OpOp2.MULT;
            OpOp2 op = sop.getOp();
            V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, null, op);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }

    if (hnew == null)
        return hi;
    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

BinaryOp (org.apache.sysml.hops.BinaryOp)58 Hop (org.apache.sysml.hops.Hop)57 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)50 LiteralOp (org.apache.sysml.hops.LiteralOp)27 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)23 UnaryOp (org.apache.sysml.hops.UnaryOp)18 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)9 IndexingOp (org.apache.sysml.hops.IndexingOp)6 ReorgOp (org.apache.sysml.hops.ReorgOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 DataOp (org.apache.sysml.hops.DataOp)5 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)5 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HashMap (java.util.HashMap)3 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3