Search in sources:

Example 46 with BinaryOp

Usage of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedDivMM.

/**
 * Rewrites weighted divide-matrix-multiply patterns into a single WDIVMM QuaternaryOp.
 *
 * Patterns 1-6 are headed by a matrix multiply {@code hi} that has an elementwise binary
 * operation (one of LOOKUP_VALID_WDIVMM_BINARY) as its left or right input:
 * <ul>
 *   <li>1)  t(U) %*% (W*(U%*%t(V))), plus the corresponding non-MULT variant</li>
 *   <li>1e) t(U) %*% (W/(U%*%t(V) + x)) with scalar x</li>
 *   <li>2)  (W*(U%*%t(V))) %*% V, plus the corresponding non-MULT variant</li>
 *   <li>2e) (W/(U%*%t(V) + x)) %*% V with scalar x</li>
 *   <li>3)  t(U) %*% ((X!=0)*(U%*%t(V)-X))</li>
 *   <li>4)  ((X!=0)*(U%*%t(V)-X)) %*% V</li>
 *   <li>5)  t(U) %*% (W*(U%*%t(V)-X))</li>
 *   <li>6)  (W*(U%*%t(V)-X)) %*% V</li>
 * </ul>
 * Pattern 7, {@code W*(U%*%t(V))}, is headed by the elementwise multiply itself. It is only
 * applied when W is sparse and U/V are not sparse, or when U and V are both dense.
 * Patterns with t(U) on the left (1, 1e, 3, 5) additionally wrap the new operator into an
 * output transpose. Patterns are tried in order and the first match wins (appliedPattern).
 *
 * @param parent parent high-level operator of {@code hi}
 * @param hi high-level operator to rewrite
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new operator (already relinked into {@code parent} at {@code pos}) if a
 *         pattern applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) {
    // replacement operator; stays null if no pattern matches
    Hop hnew = null;
    boolean appliedPattern = false;
    // note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
    // patterns 1-6: matrix multiply with a valid binary op as left input, or as right input
    // if the output has more than one column
    if (HopRewriteUtils.isMatrixMultiply(hi) && (hi.getInput().get(0) instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY) || hi.getInput().get(1) instanceof BinaryOp && // not applied for vector-vector mult
    hi.getDim2() > 1 && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY))) {
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        // Pattern 1) t(U) %*% (W*(U%*%t(V))), or the non-MULT variant of the same shape
        if (right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) right).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // prevent mv
        HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) {
                // the matched right factor of the outer product is t(V): drop an explicit
                // transpose, otherwise wrap the factor into a new t()
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                // mult=false for the other valid binary op (presumably DIV, see LOOKUP_VALID_WDIVMM_BINARY[1])
                boolean mult = ((BinaryOp) right).getOp() == OpOp2.MULT;
                // NOTE(review): the int argument (1 here; 2 for right-side patterns, 0 for pattern 7)
                // presumably selects the WDivMM variant - confirm against QuaternaryOp
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM1 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
        if (!appliedPattern && // DIV
        HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) && // prevent mv
        HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) && HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS) && right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            // scalar epsilon x added to the outer product
            Hop X = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 3, false, // 3=>DIV_LEFT_EPS
                false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM1e (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 2) (W*(U%*%t(V))) %*% V, or the non-MULT variant of the same shape
        if (!appliedPattern && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) left).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // prevent mv
        HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) {
                // 'right' was just checked to be the transpose of the matched factor, so it is
                // reused here instead of creating a new t()
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                boolean mult = ((BinaryOp) left).getOp() == OpOp2.MULT;
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM2 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 2e) (W/(U%*%t(V) + x)) %*% V
        if (!appliedPattern && // DIV
        HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) && // prevent mv
        HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) && HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS) && left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            // scalar epsilon x added to the outer product
            Hop X = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 4, false, // 4=>DIV_RIGHT_EPS
                false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM2e (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (// W-X constraint
            HopRewriteUtils.isNonZeroIndicator(W, X) && // t(U)-U constraint
            HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                // W is the non-zero indicator of X, so only X is passed as the weight input
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM3 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (// W-X constraint
            HopRewriteUtils.isNonZeroIndicator(W, X) && // V-t(V) constraint
            HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                // W is the non-zero indicator of X, so only X is passed as the weight input
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM4 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (// t(U)-U constraint
            HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = HopRewriteUtils.createTranspose(V);
                else
                    V = V.getInput().get(0);
                // note: x and w exchanged compared to patterns 1-4, 7
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 1, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM5 (line " + hi.getBeginLine() + ")");
            }
        }
        // Pattern 6) (W*(U%*%t(V)-X)) %*% V
        if (!appliedPattern && // MULT
        HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS) && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (// V-t(V) constraint
            HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V))
                    V = right;
                else
                    V = V.getInput().get(0);
                // note: x and w exchanged compared to patterns 1-4, 7
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 2, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM6 (line " + hi.getBeginLine() + ")");
            }
        }
    }
    // Pattern 7) (W*(U%*%t(V)))
    // headed by the elementwise multiply itself; W must have more columns than one block
    if (!appliedPattern && // MULT
    HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) && // prevent mv
    HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && // not applied for vector-vector mult
    hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1)) && // no mmchain
    (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) && // BLOCKSIZE CONSTRAINT
    HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0), true)) {
        Hop W = hi.getInput().get(0);
        Hop U = hi.getInput().get(1).getInput().get(0);
        Hop V = hi.getInput().get(1).getInput().get(1);
        // W is sparse and U/V unknown or dense; or if U/V are dense
        if ((HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V))) {
            // same V normalization as above: drop an explicit t(), otherwise add one
            V = !HopRewriteUtils.isTransposeOperation(V) ? HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedDivMM7 (line " + hi.getBeginLine() + ")");
        }
    }
    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 47 with BinaryOp

Usage of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyScalarMVBinaryOperation.

private static Hop simplifyScalarMVBinaryOperation(Hop hi) {
    // candidates: matrix-matrix binary ops that also have a matrix-scalar form (e.g., X * s)
    if (!(hi instanceof BinaryOp) || !((BinaryOp) hi).supportsMatrixScalarOperations())
        return hi;
    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);
    if (lhs.getDataType() != DataType.MATRIX || rhs.getDataType() != DataType.MATRIX)
        return hi;
    // X * s -> X * as.scalar(s), if the right operand is known to be a 1x1 matrix
    boolean rhsIsOneByOne = HopRewriteUtils.isDimsKnown(rhs) && rhs.getDim1() == 1 && rhs.getDim2() == 1;
    if (rhsIsOneByOne) {
        // replace the link to the right child (position 1) by a cast to scalar
        UnaryOp toScalar = HopRewriteUtils.createUnary(rhs, OpOp1.CAST_AS_SCALAR);
        HopRewriteUtils.replaceChildReference(hi, rhs, toScalar, 1);
        LOG.debug("Applied simplifyScalarMVBinaryOperation.");
    }
    return hi;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 48 with BinaryOp

Usage of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method pushdownSumOnAdditiveBinary.

/**
 * Pushes a full sum down through an additive binary operation:
 * sum(A+B) -> sum(A)+sum(B) and sum(A-B) -> sum(A)-sum(B).
 *
 * The rewrite requires a row-and-column sum over a binary operation that has a single
 * parent, and two matrix operands of equal size.
 *
 * @param parent the parent high-level operator
 * @param hi high-level operator (candidate full sum)
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new binary operator if the rewrite applied, otherwise {@code hi}
 */
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) {
    boolean isFullSum = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM;
    if (!isFullSum)
        return hi;
    Hop input = hi.getInput().get(0);
    // the summed binary operation must not be shared with other consumers
    if (!(input instanceof BinaryOp) || input.getParent().size() != 1)
        return hi;
    BinaryOp bop = (BinaryOp) input;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);
    // dims(A) == dims(B), both matrices
    boolean equalSizedMatrices = HopRewriteUtils.isEqualSize(lhs, rhs)
        && lhs.getDataType() == DataType.MATRIX
        && rhs.getDataType() == DataType.MATRIX;
    if (!equalSizedMatrices)
        return hi;
    OpOp2 op = bop.getOp();
    if (op != OpOp2.PLUS && op != OpOp2.MINUS)
        return hi;
    // build sum(A) op sum(B) and rewire it in place of the original sum
    AggUnaryOp sumLhs = HopRewriteUtils.createSum(lhs);
    AggUnaryOp sumRhs = HopRewriteUtils.createSum(rhs);
    BinaryOp pushed = HopRewriteUtils.createBinary(sumLhs, sumRhs, op);
    HopRewriteUtils.replaceChildReference(parent, hi, pushed, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, bop);
    LOG.debug("Applied pushdownSumOnAdditiveBinary.");
    return pushed;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 49 with BinaryOp

Usage of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationStatic, method simplifyTransposeAggBinBinaryChains.

/**
 * Pushes a transpose through an elementwise binary operation over a product of two
 * transposes: t(t(A)%*%t(B) op C) -> (B%*%A) op t(C), e.g.,
 * t(t(A)%*%t(B)+C) -> B%*%A+t(C).
 *
 * This holds because t(t(A)%*%t(B)) = B%*%A and t(X op C) = t(X) op t(C) for elementwise
 * operations. The matched binary operator is therefore kept as-is; it is not replaced
 * by a fixed PLUS. The rewrite applies only if both transposes feeding the matrix
 * multiply have a single consumer.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate transpose)
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new binary operator if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isTransposeOperation(hi) && // basic binary
    hi.getInput().get(0) instanceof BinaryOp && ((BinaryOp) hi.getInput().get(0)).supportsMatrixScalarOperations()) {
        BinaryOp binary = (BinaryOp) hi.getInput().get(0);
        Hop left = binary.getInput().get(0);
        Hop C = binary.getInput().get(1);
        // check matrix mult and both inputs transposes w/ single consumer
        if (left instanceof AggBinaryOp && C.getDataType().isMatrix()
            && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) && left.getInput().get(0).getParent().size() == 1
            && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) && left.getInput().get(1).getParent().size() == 1) {
            Hop A = left.getInput().get(0).getInput().get(0);
            Hop B = left.getInput().get(1).getInput().get(0);
            AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A);
            ReorgOp rop = HopRewriteUtils.createTranspose(C);
            // keep the matched elementwise operator (e.g., *, -, /); t() distributes over
            // all of them, so hard-coding PLUS would change the result
            BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, binary.getOp());
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            // log the line of the rewritten (original) operator, not of the freshly created one
            LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line " + hi.getBeginLine() + ").");
            hi = bop;
        }
    }
    return hi;
}
Also used: AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 50 with BinaryOp

Usage of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationStatic, method simplifyBushyBinaryOperation.

/**
 * Reassociates nested associative binary operations below a matrix multiply into bushier
 * trees:
 * <ul>
 *   <li>(X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)</li>
 *   <li>(X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)</li>
 *   <li>(((Z%*%v)*X)*Y) -> (Z%*%v)*(X*Y)</li>
 * </ul>
 *
 * Note: the rewrite requires a ba() at the leaf and at the root, rather than data at the
 * leaf. This avoids reorganizing too eagerly, which would lose additional rewrite potential.
 * The rewrite has two goals: (1) enable XtwXv, and (2) increase piggybacking potential by
 * creating bushy trees.
 *
 * @param parent parent high-level operator (must be a matrix multiply)
 * @param hi high-level operator (candidate binary operation)
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new root binary operator if a rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyBushyBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp) || !(parent instanceof AggBinaryOp))
        return hi;
    BinaryOp outer = (BinaryOp) hi;
    Hop x = outer.getInput().get(0);
    Hop y = outer.getInput().get(1);
    OpOp2 op = outer.getOp();
    boolean bothMatrices = x.getDataType() == DataType.MATRIX && y.getDataType() == DataType.MATRIX;
    if (!bothMatrices || !HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY))
        return hi;

    // right-nested: (X op (Y op ba())) -> (X op Y) op ba()
    if (y instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) y;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        if (inner.getOp() == op && innerRhs.getDataType() == DataType.MATRIX && innerRhs instanceof AggBinaryOp) {
            BinaryOp leaves = HopRewriteUtils.createBinary(x, innerLhs, op);
            BinaryOp root = HopRewriteUtils.createBinary(leaves, innerRhs, op);
            HopRewriteUtils.replaceChildReference(parent, outer, root, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation1");
            return root;
        }
    }

    // left-nested: ((ba() op X) op Y) -> ba() op (X op Y)
    if (x instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) x;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        if (inner.getOp() == op && innerLhs.getDataType() == DataType.MATRIX && innerLhs instanceof AggBinaryOp
            // X has more than one column, or Y has exactly one column
            && (innerRhs.getDim2() > 1 || y.getDim2() == 1)
            // X has more than one row, or Y has exactly one row
            && (innerRhs.getDim1() > 1 || y.getDim1() == 1)) {
            BinaryOp leaves = HopRewriteUtils.createBinary(innerRhs, y, op);
            BinaryOp root = HopRewriteUtils.createBinary(innerLhs, leaves, op);
            HopRewriteUtils.replaceChildReference(parent, outer, root, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation2");
            return root;
        }
    }
    return hi;
}
Also used: AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

BinaryOp (org.apache.sysml.hops.BinaryOp)58 Hop (org.apache.sysml.hops.Hop)57 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)50 LiteralOp (org.apache.sysml.hops.LiteralOp)27 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)23 UnaryOp (org.apache.sysml.hops.UnaryOp)18 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)9 IndexingOp (org.apache.sysml.hops.IndexingOp)6 ReorgOp (org.apache.sysml.hops.ReorgOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 DataOp (org.apache.sysml.hops.DataOp)5 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)5 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HashMap (java.util.HashMap)3 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3