Search in sources :

Example 71 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifySumMatrixMult.

/**
 * Pushes a sum aggregate below a matrix multiplication so that the
 * (potentially large) product A%*%B is not materialized:
 * <ul>
 *   <li>sum(A%*%B)     -> sum(t(colSums(A))*rowSums(B))</li>
 *   <li>colSums(A%*%B) -> colSums(A)%*%B</li>
 *   <li>rowSums(A%*%B) -> A%*%rowSums(B)</li>
 * </ul>
 * The aggregate node {@code hi} itself is kept; only its input sub-DAG is
 * exchanged.
 *
 * @param parent parent of {@code hi} (not used by this rewrite)
 * @param hi     current hop, candidate sum aggregate
 * @param pos    position of {@code hi} in the parent's inputs (not used by this rewrite)
 * @return {@code hi}, with its input replaced if the rewrite applied
 */
private static Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) {
    // candidate must be a sum aggregate
    if (!(hi instanceof AggUnaryOp) || ((AggUnaryOp) hi).getOp() != AggOp.SUM) {
        return hi;
    }
    AggUnaryOp sumOp = (AggUnaryOp) hi;
    Hop matMult = sumOp.getInput().get(0);

    // the aggregated input must be a matrix multiplication A%*%B
    if (!(matMult instanceof AggBinaryOp)) {
        return hi;
    }
    // skip dot products, i.e., products where neither dimension exceeds 1
    if (matMult.getDim1() <= 1 && matMult.getDim2() <= 1) {
        return hi;
    }
    // if the sum is not the only consumer of the product, the rewrite is
    // not applied in order to prevent redundant computation
    if (matMult.getParent().size() != 1) {
        return hi;
    }

    Hop lhs = matMult.getInput().get(0);
    Hop rhs = matMult.getInput().get(1);

    // detach the matrix multiplication from the aggregate
    HopRewriteUtils.removeChildReference(hi, matMult);

    // build the replacement sub-DAG depending on the aggregation direction
    Direction dir = sumOp.getDirection();
    Hop replacement = null;
    if (dir == Direction.RowCol) {
        // sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product
        AggUnaryOp colSumsLhs = HopRewriteUtils.createAggUnaryOp(lhs, AggOp.SUM, Direction.Col);
        ReorgOp colSumsLhsT = HopRewriteUtils.createTranspose(colSumsLhs);
        AggUnaryOp rowSumsRhs = HopRewriteUtils.createAggUnaryOp(rhs, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createBinary(colSumsLhsT, rowSumsRhs, OpOp2.MULT);
        LOG.debug("Applied simplifySumMatrixMult RC.");
    } else if (dir == Direction.Col) {
        // colSums(A%*%B) -> colSums(A)%*%B
        AggUnaryOp colSumsLhs = HopRewriteUtils.createAggUnaryOp(lhs, AggOp.SUM, Direction.Col);
        replacement = HopRewriteUtils.createMatrixMultiply(colSumsLhs, rhs);
        LOG.debug("Applied simplifySumMatrixMult C.");
    } else if (dir == Direction.Row) {
        // rowSums(A%*%B) -> A%*%rowSums(B)
        AggUnaryOp rowSumsRhs = HopRewriteUtils.createAggUnaryOp(rhs, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createMatrixMultiply(lhs, rowSumsRhs);
        LOG.debug("Applied simplifySumMatrixMult R.");
    }

    // hang the new sub-DAG under the (unchanged) aggregate node
    HopRewriteUtils.addChildReference(hi, replacement, 0);
    hi.refreshSizeInformation();

    // drop the old matrix multiplication if nothing references it anymore
    HopRewriteUtils.cleanupUnreferenced(matMult);
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 72 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method fuseAxpyBinaryOperationChain.

/**
 * Fuses a matrix addition/subtraction with a scalar-matrix multiplication
 * into a single ternary operator:
 * <ul>
 *   <li>(a) X + s*Y -> X +* sY</li>
 *   <li>(b) s*Y + X -> X +* sY</li>
 *   <li>(c) X - s*Y -> X -* sY</li>
 * </ul>
 * A pattern only matches if both operands have equal size and the product
 * s*Y has a single consumer.
 *
 * @param parent parent of {@code hi}
 * @param hi     current hop, candidate plus/minus binary operation
 * @param pos    position of {@code hi} in the parent's inputs
 * @return the replacement hop if a pattern matched, otherwise {@code hi}
 */
private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    if (bop.isOuterVectorOperator()) {
        return hi;
    }
    boolean isPlus = bop.getOp() == OpOp2.PLUS;
    boolean isMinus = bop.getOp() == OpOp2.MINUS;
    if (!isPlus && !isMinus) {
        return hi;
    }

    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);
    Hop fused = null;

    if (isPlus
        && lhs.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(rhs)
        && HopRewriteUtils.isEqualSize(lhs, rhs)
        && rhs.getParent().size() == 1) {
        // pattern (a) X + s*Y -> X +* sY
        fused = fuseScaledOperand(lhs, rhs, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " + hi.getBeginLine() + ")");
    } else if (isPlus
        && rhs.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(lhs)
        && HopRewriteUtils.isEqualSize(lhs, rhs)
        && lhs.getParent().size() == 1) {
        // pattern (b) s*Y + X -> X +* sY
        fused = fuseScaledOperand(rhs, lhs, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " + hi.getBeginLine() + ")");
    } else if (isMinus
        && lhs.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(rhs)
        && HopRewriteUtils.isEqualSize(lhs, rhs)
        && rhs.getParent().size() == 1) {
        // pattern (c) X - s*Y -> X -* sY
        fused = fuseScaledOperand(lhs, rhs, OpOp3.MINUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " + hi.getBeginLine() + ")");
    }

    // rewire parent-child operators only if one of the patterns matched
    if (fused == null) {
        return hi;
    }
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}

/**
 * Builds the fused replacement for {@code base (+|-) s*Y}, where
 * {@code product} is the scalar-matrix multiplication s*Y.
 *
 * @param base    the plain matrix operand X
 * @param product the scalar-matrix multiplication whose two inputs are s and Y in either order
 * @param fusedOp ternary operator to create (plus-mult or minus-mult)
 * @return {@code base} itself if s is a literal with value 0, otherwise a
 *         new ternary operator over (base, s, Y)
 */
private static Hop fuseScaledOperand(Hop base, Hop product, OpOp3 fusedOp) {
    // the scalar may sit on either side of the multiplication
    boolean scalarFirst = product.getInput().get(0).getDataType() == DataType.SCALAR;
    Hop scalar = product.getInput().get(scalarFirst ? 0 : 1);
    Hop matrix = product.getInput().get(scalarFirst ? 1 : 0);

    // a literal zero scalar short-circuits to the base operand
    if (scalar instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) scalar) == 0) {
        return base;
    }
    return HopRewriteUtils.createTernaryOp(base, scalar, matrix, fusedOp);
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 73 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method removeUnnecessaryRightIndexing.

/**
 * Replaces a right indexing operation by its input whenever
 * {@link HopRewriteUtils#isUnnecessaryRightIndexing(Hop)} reports the
 * indexing as unnecessary.
 *
 * @param parent parent of {@code hi}
 * @param hi     current hop, candidate right indexing operation
 * @param pos    position of {@code hi} in the parent's inputs
 * @return the indexing input if the rewrite applied, otherwise {@code hi}
 */
private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos) {
    if (!HopRewriteUtils.isUnnecessaryRightIndexing(hi)) {
        return hi;
    }

    // bypass the indexing: link the parent directly to the indexed data
    Hop source = hi.getInput().get(0);
    HopRewriteUtils.replaceChildReference(parent, hi, source, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied removeUnnecessaryRightIndexing");
    return source;
}
Also used : Hop(org.apache.sysml.hops.Hop)

Example 74 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyEmptyMatrixMult.

/**
 * Replaces a matrix multiplication X%*%Y with a generated matrix of zeros,
 * matrix(0, ...), if at least one of the two inputs is empty.
 *
 * @param parent parent of {@code hi}
 * @param hi     current hop, candidate matrix multiplication
 * @param pos    position of {@code hi} in the parent's inputs
 * @return the new data generator hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos) {
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }

    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    // nothing to do unless one of the inputs is empty
    if (!HopRewriteUtils.isEmpty(lhs) && !HopRewriteUtils.isEmpty(rhs)) {
        return hi;
    }

    // create the zero-valued datagen and hook it into the parent
    Hop zeros = HopRewriteUtils.createDataGenOp(lhs, rhs, 0);
    HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
    LOG.debug("Applied simplifyEmptyMatrixMult");
    return zeros;
}
Also used : Hop(org.apache.sysml.hops.Hop)

Example 75 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method removeUnnecessaryLeftIndexing.

/**
 * Replaces a left indexing operation by its right-hand-side input when the
 * output of the indexing has the same size as that input, in which case
 * the indexing is not needed.
 *
 * @param parent parent of {@code hi}
 * @param hi     current hop, candidate left indexing operation
 * @param pos    position of {@code hi} in the parent's inputs
 * @return the right-hand-side input if the rewrite applied, otherwise {@code hi}
 */
private static Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof LeftIndexingOp)) {
        return hi;
    }

    // second input of the left indexing: the rhs matrix/frame
    Hop rhs = hi.getInput().get(1);

    // rewrite only if indexing output and rhs have equal dims
    if (!HopRewriteUtils.isEqualSize(hi, rhs)) {
        return hi;
    }

    // remove the unnecessary left indexing by linking the parent to the rhs
    HopRewriteUtils.replaceChildReference(parent, hi, rhs, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied removeUnnecessaryLeftIndexing");
    return rhs;
}
Also used : Hop(org.apache.sysml.hops.Hop) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp)

Aggregations

Hop (org.apache.sysml.hops.Hop)307 LiteralOp (org.apache.sysml.hops.LiteralOp)94 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)65 BinaryOp (org.apache.sysml.hops.BinaryOp)63 ArrayList (java.util.ArrayList)61 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)61 HashMap (java.util.HashMap)44 DataOp (org.apache.sysml.hops.DataOp)41 UnaryOp (org.apache.sysml.hops.UnaryOp)41 HashSet (java.util.HashSet)39 ReorgOp (org.apache.sysml.hops.ReorgOp)32 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)28 StatementBlock (org.apache.sysml.parser.StatementBlock)28 IndexingOp (org.apache.sysml.hops.IndexingOp)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)23 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)23 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 DataGenOp (org.apache.sysml.hops.DataGenOp)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HopsException (org.apache.sysml.hops.HopsException)18