Search in sources :

Example 36 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

Method simplifyWeightedDivMM of the class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses weighted divide/multiply matrix-multiplication patterns into a single
 * {@code QuaternaryOp} of type {@code OpOp4.WDIVMM}.
 *
 * Patterns 1, 1e, 2, 2e, 3, 4, 5 and 6 are only tried when {@code hi} is a matrix multiply
 * with a qualifying binary-op input (see the outer guard). Pattern 7 is tried on {@code hi}
 * itself, as a cellwise multiply with an outer-product-like matrix multiply. Patterns are
 * tried in order, and each one is skipped once an earlier one has applied, so the first
 * match wins. Patterns with {@code t(U)} as the left input (1, 1e, 3, 5) wrap the fused
 * operator in an output transpose.
 *
 * @param parent parent high-level operator; on success its child reference at {@code pos}
 *        is replaced by the new operator
 * @param hi high-level operator to inspect
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new (possibly transposed) WDIVMM operator if a pattern applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) {
    Hop hnew = null;
    boolean appliedPattern = false;
    // note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
    // Outer guard for patterns 1-6: hi is a matrix multiply and its left input, or its right
    // input (only if hi has more than one column), is a binary op in LOOKUP_VALID_WDIVMM_BINARY.
    if (HopRewriteUtils.isMatrixMultiply(hi) && (hi.getInput().get(0) instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY) || hi.getInput().get(1) instanceof BinaryOp && // not applied for vector-vector mult
    hi.getDim2() > 1 && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY))) {
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        // alternative pattern: t(U) %*% (W*(U%*%t(V)))
        // (pattern 1; the binary operator is any op in LOOKUP_VALID_WDIVMM_BINARY)
        if (right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) right).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // prevent mv
        HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) {
                // normalize V: strip an existing transpose, otherwise add one
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                // flag is true for MULT and false for any other valid op
                boolean mult = ((BinaryOp) right).getOp() == OpOp2.MULT;
                // NOTE(review): the int argument appears to select the WDIVMM variant
                // (1 here; see pattern-specific comments below) - confirm against QuaternaryOp
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM1 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
        // (x must be a scalar; it is passed as the fourth WDIVMM input)
        if (!appliedPattern && // DIV
        HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) && // prevent mv
        HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) && HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS) && right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) {
                // normalize V: strip an existing transpose, otherwise add one
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 3, false, // 3=>DIV_LEFT_EPS
                false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM1e (line " + hi.getBeginLine() + ")");
            }
        }
        // alternative pattern: (W*(U%*%t(V))) %*% V
        // (pattern 2; the binary operator is any op in LOOKUP_VALID_WDIVMM_BINARY)
        if (!appliedPattern && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) left).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // prevent mv
        HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) {
                // presumably reuses `right` (matched above as the transpose counterpart of V)
                // instead of creating a new transpose node - confirm isTransposeOfItself semantics
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                boolean mult = ((BinaryOp) left).getOp() == OpOp2.MULT;
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // no output transpose here: the result is used as-is
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM2 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 2e) (W/(U%*%t(V) + x)) %*% V
        // (x must be a scalar; it is passed as the fourth WDIVMM input)
        if (!appliedPattern && // DIV
        HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) && // prevent mv
        HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) && HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS) && left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 4, false, // 4=>DIV_RIGHT_EPS
                false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM2e (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
        // (X is passed as the first WDIVMM input; W is only used for the output block sizes)
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (// W-X constraint
            HopRewriteUtils.isNonZeroIndicator(W, X) && // t(U)-U constraint
            HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM3 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
        // (mirror of pattern 3 with the weighted product on the left)
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (// W-X constraint
            HopRewriteUtils.isNonZeroIndicator(W, X) && // V-t(V) constraint
            HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM4 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
        // (like pattern 3 but with an arbitrary weight W; X becomes the fourth WDIVMM input)
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (// t(U)-U constraint
            HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                // note: x and w exchanged compared to patterns 1-4, 7
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 1, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM5 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 6) (W*(U%*%t(V)-X)) %*% V
        // (mirror of pattern 5 with the weighted product on the left)
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (// V-t(V) constraint
            HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                // note: x and w exchanged compared to patterns 1-4, 7
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 2, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM6 (line " + hi.getBeginLine() + ")");
            }
        }
    }
    // Pattern 7) (W*(U%*%t(V)))
    // Matched on hi itself (a cellwise multiply), outside the matrix-multiply guard above.
    if (!appliedPattern && // MULT
    HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) && // prevent mv
    HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && // not applied for vector-vector mult
    hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1)) && // no mmchain
    (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) && // BLOCKSIZE CONSTRAINT
    HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0), true)) {
        Hop W = hi.getInput().get(0);
        Hop U = hi.getInput().get(1).getInput().get(0);
        Hop V = hi.getInput().get(1).getInput().get(1);
        // W is sparse and U/V unknown or dense; or if U/V are dense
        if ((HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V))) {
            // normalize V: strip an existing transpose, otherwise add one
            V = !HopRewriteUtils.isTransposeOperation(V) ? HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedDivMM7 (line " + hi.getBeginLine() + ")");
        }
    }
    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 37 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

Method simplifyRowSumsMVMult of the class RewriteAlgebraicSimplificationDynamic.

/**
 * Rewrites rowSums(X * r), where X is a matrix and r a row vector, into X %*% t(r).
 * A transpose that turns out to be redundant (e.g., if r == t(z)) is removed by another rewrite.
 *
 * @param parent parent high-level operator; its child reference at {@code pos} is replaced on success
 * @param hi high-level operator to inspect
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new matrix multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyRowSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop product = agg.getInput().get(0);
    boolean isRowSums = agg.getOp() == AggOp.SUM && agg.getDirection() == Direction.Row;
    if (!isRowSums || !HopRewriteUtils.isBinary(product, OpOp2.MULT)) {
        return hi;
    }
    Hop matrix = product.getInput().get(0);
    Hop rowVector = product.getInput().get(1);
    // only matrix-(row vector) products qualify
    if (matrix.getDim1() > 1 && matrix.getDim2() > 1 && rowVector.getDim1() == 1 && rowVector.getDim2() > 1) {
        ReorgOp transposed = HopRewriteUtils.createTranspose(rowVector);
        AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(matrix, transposed);
        HopRewriteUtils.replaceChildReference(parent, hi, matMult, pos);
        HopRewriteUtils.cleanupUnreferenced(hi, product);
        LOG.debug("Applied simplifyRowSumsMVMult");
        return matMult;
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 38 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

Method simplifyTransposeAggBinBinaryChains of the class RewriteAlgebraicSimplificationStatic.

/**
 * Pushes a transpose through a cellwise binary operation over a matrix multiply of two transposes.
 * Patterns: t(t(A)%*%t(B) op C) -> B%*%A op t(C), e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
 *
 * The rewrite requires C to be a matrix and both transposes feeding the matrix multiply to have
 * a single consumer. Because transpose distributes over any cellwise binary operation, the
 * original operator is preserved. Earlier code hard-coded PLUS, which silently changed
 * semantics for operators such as MINUS, MULT or DIV that also pass the guard.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isTransposeOperation(hi) && // basic binary
    hi.getInput().get(0) instanceof BinaryOp && ((BinaryOp) hi.getInput().get(0)).supportsMatrixScalarOperations()) {
        BinaryOp binary = (BinaryOp) hi.getInput().get(0);
        Hop left = binary.getInput().get(0);
        Hop C = binary.getInput().get(1);
        // check matrix mult and both inputs transposes w/ single consumer
        if (left instanceof AggBinaryOp && C.getDataType().isMatrix() && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) && left.getInput().get(0).getParent().size() == 1 && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) && left.getInput().get(1).getParent().size() == 1) {
            Hop A = left.getInput().get(0).getInput().get(0);
            Hop B = left.getInput().get(1).getInput().get(0);
            // t(t(A)%*%t(B)) == B%*%A
            AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A);
            ReorgOp rop = HopRewriteUtils.createTranspose(C);
            // keep the original cellwise operator: t(X op C) == t(X) op t(C)
            BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, binary.getOp());
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            // log the source line of the matched transpose, not of the newly created hop
            LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line " + hi.getBeginLine() + ").");
            hi = bop;
        }
    }
    return hi;
}
Also used: AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 39 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

Method simplifyBushyBinaryOperation of the class RewriteAlgebraicSimplificationStatic.

/**
 * Reassociates chains of the same associative binary operation into bushy trees:
 * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
 * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
 * ((Z%*%v)*X)*Y   -> (Z%*%v)*(X*Y)
 *
 * Note: the rewrite requires a matrix multiply (ba()) both as consumer and as inner leaf, rather
 * than plain data at the leaves. This avoids reorganizing too eagerly, which would lose other
 * rewrite potential. Goals: (1) enable XtwXv, and (2) increase piggybacking potential by
 * creating bushy trees.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop simplifyBushyBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp) || !(parent instanceof AggBinaryOp)) {
        return hi;
    }
    BinaryOp outer = (BinaryOp) hi;
    Hop lhs = outer.getInput().get(0);
    Hop rhs = outer.getInput().get(1);
    OpOp2 outerOp = outer.getOp();
    boolean bothMatrices = lhs.getDataType() == DataType.MATRIX && rhs.getDataType() == DataType.MATRIX;
    if (!bothMatrices || !HopRewriteUtils.isValidOp(outerOp, LOOKUP_VALID_ASSOCIATIVE_BINARY)) {
        return hi;
    }
    // case 1: X op (Y op ba()) -> (X op Y) op ba()
    if (rhs instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) rhs;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        if (inner.getOp() == outerOp && innerRhs.getDataType() == DataType.MATRIX && innerRhs instanceof AggBinaryOp) {
            BinaryOp grouped = HopRewriteUtils.createBinary(lhs, innerLhs, outerOp);
            BinaryOp replacement = HopRewriteUtils.createBinary(grouped, innerRhs, outerOp);
            HopRewriteUtils.replaceChildReference(parent, outer, replacement, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation1");
            return replacement;
        }
    }
    // case 2: (ba() op X) op Y -> ba() op (X op Y)
    if (lhs instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) lhs;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        // X op Y must stay a valid operation: X may only be a vector (in a dimension) if Y is one too
        if (inner.getOp() == outerOp && innerLhs.getDataType() == DataType.MATRIX && innerLhs instanceof AggBinaryOp
            && (innerRhs.getDim2() > 1 || rhs.getDim2() == 1)
            && (innerRhs.getDim1() > 1 || rhs.getDim1() == 1)) {
            BinaryOp grouped = HopRewriteUtils.createBinary(innerRhs, rhs, outerOp);
            BinaryOp replacement = HopRewriteUtils.createBinary(innerLhs, grouped, outerOp);
            HopRewriteUtils.replaceChildReference(parent, outer, replacement, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation2");
            return replacement;
        }
    }
    return hi;
}
Also used: AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 40 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

Method canonicalizeMatrixMultScalarAdd of the class RewriteAlgebraicSimplificationStatic.

/**
 * Canonicalizes scalar additions to a matrix multiply so that later rewrites (e.g., wdivmm or
 * wcemm with epsilon) only need to match the single form U%*%V + s:
 * eps + U%*%V -> U%*%V + eps (operands swapped in place), and
 * U%*%V - eps -> U%*%V + (-eps) (operator switched to PLUS, scalar negated).
 * The operator {@code hi} is modified in place; no new root is created.
 *
 * @param hi high-level operator
 * @return high-level operator (always {@code hi})
 */
private static Hop canonicalizeMatrixMultScalarAdd(Hop hi) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    Hop first = hi.getInput().get(0);
    Hop second = hi.getInput().get(1);
    boolean scalarPlusMatMult = first.getDataType().isScalar() && second instanceof AggBinaryOp && bop.getOp() == OpOp2.PLUS;
    if (scalarPlusMatMult) {
        // (eps + U%*%V) -> (U%*%V + eps): rebuild the child list in swapped order
        HopRewriteUtils.removeAllChildReferences(bop);
        HopRewriteUtils.addChildReference(bop, second, 0);
        HopRewriteUtils.addChildReference(bop, first, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line " + hi.getBeginLine() + ").");
    } else if (second.getDataType().isScalar() && first instanceof AggBinaryOp && bop.getOp() == OpOp2.MINUS) {
        // (U%*%V - eps) -> (U%*%V + (-eps))
        bop.setOp(OpOp2.PLUS);
        Hop negated = HopRewriteUtils.createBinaryMinus(second);
        HopRewriteUtils.replaceChildReference(bop, second, negated, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line " + hi.getBeginLine() + ").");
    }
    return hi;
}
Also used: AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)56 Hop (org.apache.sysml.hops.Hop)47 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)32 BinaryOp (org.apache.sysml.hops.BinaryOp)27 LiteralOp (org.apache.sysml.hops.LiteralOp)21 ReorgOp (org.apache.sysml.hops.ReorgOp)15 UnaryOp (org.apache.sysml.hops.UnaryOp)15 TernaryOp (org.apache.sysml.hops.TernaryOp)11 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)11 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)10 IndexingOp (org.apache.sysml.hops.IndexingOp)9 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)9 CNode (org.apache.sysml.hops.codegen.cplan.CNode)8 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)6 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)6 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)6 DataOp (org.apache.sysml.hops.DataOp)4 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)4 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)4 TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType)4