Search in sources:

Example 76 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifyColwiseAggregate.

/**
 * Simplifies a column-wise aggregate (Direction.Col) whose input is a vector.
 * <ul>
 * <li>Single-row input with VAR: the column variances are all zero, so the aggregate is
 *     replaced by a zero-filled DataGenOp.</li>
 * <li>Single-row input with any other op accepted by LOOKUP_VALID_ROW_COL_AGGREGATE,
 *     except SUM_SQ: the aggregate is replaced by its input (each column holds one value).</li>
 * <li>Single-column input: the aggregate is turned into a full (RowCol) scalar aggregate
 *     wrapped in a CAST_AS_MATRIX, so consumers still see a matrix.</li>
 * </ul>
 *
 * @param parent the hop that consumes {@code hi}
 * @param hi     the candidate hop
 * @param pos    the position of {@code hi} in {@code parent}'s inputs
 * @return the hop that now stands in place of {@code hi} (or {@code hi} itself if nothing applied)
 */
private static Hop simplifyColwiseAggregate(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggUnaryOp) {
        AggUnaryOp uhi = (AggUnaryOp) hi;
        Hop input = uhi.getInput().get(0);
        if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)) {
            if (uhi.getDirection() == Direction.Col) {
                if (input.getDim1() == 1) {
                    if (uhi.getOp() == AggOp.VAR) {
                        // For the column variance aggregation, if the input is a row vector,
                        // the column variances will each be zero.
                        // Therefore, perform a rewrite from COLVAR(X) to a row vector of zeros.
                        Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0);
                        HopRewriteUtils.replaceChildReference(parent, hi, emptyRow, pos);
                        HopRewriteUtils.cleanupUnreferenced(hi, input);
                        hi = emptyRow;
                        LOG.debug("Applied simplifyColwiseAggregate for colVars");
                    } else if (uhi.getOp() != AggOp.SUM_SQ) {
                        // Sums, extrema and means over a single row are the row itself.
                        // Therefore, remove unnecessary col aggregation for 1 row.
                        HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
                        HopRewriteUtils.cleanupUnreferenced(hi);
                        hi = input;
                        LOG.debug("Applied simplifyColwiseAggregate1");
                    }
                    // SUM_SQ over a single row yields x^2 elementwise rather than x, so the
                    // identity rewrite would be wrong; leave that aggregate untouched.
                } else if (input.getDim2() == 1) {
                    // snapshot the old parents before createUnary adds the cast as a new parent
                    ArrayList<Hop> parents = new ArrayList<>(hi.getParent());
                    // simplify col-aggregate to full aggregate
                    uhi.setDirection(Direction.RowCol);
                    uhi.setDataType(DataType.SCALAR);
                    // create cast to keep same output datatype
                    UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
                    // rehang cast under all original parents
                    for (Hop p : parents) {
                        int ix = HopRewriteUtils.getChildReferencePos(p, hi);
                        HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
                    }
                    hi = cast;
                    LOG.debug("Applied simplifyColwiseAggregate2");
                }
            }
        }
    }
    return hi;
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), UnaryOp (org.apache.sysml.hops.UnaryOp), Hop (org.apache.sysml.hops.Hop), ArrayList (java.util.ArrayList)

Example 77 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifyScalarMatrixMult.

/**
 * Rewrites a matrix multiplication with a known 1x1 operand into a scalar-matrix multiply:
 * {@code y %*% X -> as.scalar(y) * X} and {@code X %*% y -> as.scalar(y) * X}.
 * The left operand is checked first; the replaced hop is detached if no longer referenced.
 *
 * @param parent the hop that consumes {@code hi}
 * @param hi     the candidate hop
 * @param pos    the position of {@code hi} in {@code parent}'s inputs
 * @return the new multiply hop, or {@code hi} unchanged if the pattern does not match
 */
private static Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos) {
    // only X %*% Y qualifies
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }
    final Hop lhs = hi.getInput().get(0);
    final Hop rhs = hi.getInput().get(1);

    final Hop scalarSide;
    final Hop matrixSide;
    final String variant;
    if (HopRewriteUtils.isDimsKnown(lhs) && lhs.getDim1() == 1 && lhs.getDim2() == 1) {
        // y %*% X -> as.scalar(y) * X
        scalarSide = lhs;
        matrixSide = rhs;
        variant = "1";
    } else if (HopRewriteUtils.isDimsKnown(rhs) && rhs.getDim1() == 1 && rhs.getDim2() == 1) {
        // X %*% y -> as.scalar(y) * X (scalar multiplication commutes)
        scalarSide = rhs;
        matrixSide = lhs;
        variant = "2";
    } else {
        return hi;
    }

    UnaryOp asScalar = HopRewriteUtils.createUnary(scalarSide, OpOp1.CAST_AS_SCALAR);
    BinaryOp product = HopRewriteUtils.createBinary(asScalar, matrixSide, OpOp2.MULT);
    // hang the new multiply under the parent and drop the old matmult if unreferenced
    HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied simplifyScalarMatrixMult" + variant);
    return product;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 78 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method removeUnnecessaryCumulativeOp.

/**
 * Removes a cumulative unary operation (as reported by
 * {@code UnaryOp.isCumulativeUnaryOperation()}) whose input is known to have exactly one row.
 * Such an op is treated here as the identity, so {@code parent} is rewired to consume the
 * input directly.
 *
 * @param parent the hop that consumes {@code hi}
 * @param hi     the candidate hop
 * @param pos    the position of {@code hi} in {@code parent}'s inputs
 * @return the input hop if the cumulative op was removed, otherwise {@code hi}
 */
private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) {
    if (hi instanceof UnaryOp && ((UnaryOp) hi).isCumulativeUnaryOperation()) {
        // input matrix
        Hop input = hi.getInput().get(0);
        // dims of input known and exactly 1 row
        if (HopRewriteUtils.isDimsKnown(input) && input.getDim1() == 1) {
            OpOp1 op = ((UnaryOp) hi).getOp();
            // remove unnecessary cumulative operator
            HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
            // detach the dropped op from its input if nothing references it anymore, so the
            // input does not keep a stale parent link (parent counts are checked by other rewrites)
            HopRewriteUtils.cleanupUnreferenced(hi);
            hi = input;
            LOG.debug("Applied removeUnnecessaryCumulativeOp: " + op);
        }
    }
    return hi;
}
Also used : OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 79 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method pushdownBinaryOperationOnDiag.

/**
 * Pushes a scalar multiplication below a vector-to-matrix diag:
 * {@code diag(v) * s -> diag(v * s)} and {@code s * diag(v) -> diag(s * v)}, where v is a
 * column vector and s a scalar.
 * <p>
 * The rewrite only fires when the binary op is the diag's single parent, so the diag can be
 * reused in place. After the rewrite the multiply is applied to the vector instead of the
 * diagonal matrix. The returned hop is the (reused) diag, now sitting on top of the multiply.
 *
 * @param parent the hop that consumes {@code hi} (all of hi's parents are relinked, not only this one)
 * @param hi     the candidate hop
 * @param pos    the position of {@code hi} in {@code parent}'s inputs (unused; positions are recomputed per parent)
 * @return the diag hop if the rewrite applied, otherwise {@code hi}
 */
@SuppressWarnings("unchecked")
private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos) {
    // diag(X)*7 --> diag(X*7): (1) the multiply then runs over the n x 1 vector instead of the
    // diagonal matrix (presumably reducing its memory needs), and
    // (2) in order to make the binary operation more efficient (dense vector vs sparse matrix)
    if (HopRewriteUtils.isBinary(hi, OpOp2.MULT)) {
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applyLeft = false;
        boolean applyRight = false;
        // left input is diag
        if (left instanceof ReorgOp && ((ReorgOp) left).getOp() == ReOrgOp.DIAG && // binary op only parent
        left.getParent().size() == 1 && // col vector
        left.getInput().get(0).getDim2() == 1 && right.getDataType() == DataType.SCALAR) {
            applyLeft = true;
        // right input is diag (mirror of the left case)
        } else if (right instanceof ReorgOp && ((ReorgOp) right).getOp() == ReOrgOp.DIAG && // binary op only parent
        right.getParent().size() == 1 && // col vector
        right.getInput().get(0).getDim2() == 1 && left.getDataType() == DataType.SCALAR) {
            applyRight = true;
        }
        // perform actual rewrite
        if (applyLeft || applyRight) {
            // remove all parent links to binary op (since we want to reorder
            // we cannot just look at the current parent); remember each parent's
            // input position so the replacement can be hung back at the same slot
            ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
            ArrayList<Integer> parentspos = new ArrayList<>();
            for (Hop lparent : parents) {
                int lpos = HopRewriteUtils.getChildReferencePos(lparent, hi);
                HopRewriteUtils.removeChildReferenceByPos(lparent, hi, lpos);
                parentspos.add(lpos);
            }
            // rewire binop-diag-input into diag-binop-input
            if (applyLeft) {
                // hi(diag(input), s) -> diag(hi(input, s))
                Hop input = left.getInput().get(0);
                HopRewriteUtils.removeChildReferenceByPos(hi, left, 0);
                HopRewriteUtils.removeChildReferenceByPos(left, input, 0);
                HopRewriteUtils.addChildReference(left, hi, 0);
                HopRewriteUtils.addChildReference(hi, input, 0);
                // the multiply now produces a vector; recompute its dims
                hi.refreshSizeInformation();
                hi = left;
            } else if (applyRight) {
                // hi(s, diag(input)) -> diag(hi(s, input)); input stays at position 1
                Hop input = right.getInput().get(0);
                HopRewriteUtils.removeChildReferenceByPos(hi, right, 1);
                HopRewriteUtils.removeChildReferenceByPos(right, input, 0);
                HopRewriteUtils.addChildReference(right, hi, 0);
                HopRewriteUtils.addChildReference(hi, input, 1);
                // the multiply now produces a vector; recompute its dims
                hi.refreshSizeInformation();
                hi = right;
            }
            // relink all parents to the diag operation at their original positions
            for (int i = 0; i < parents.size(); i++) {
                Hop lparent = parents.get(i);
                int lpos = parentspos.get(i);
                HopRewriteUtils.addChildReference(lparent, hi, lpos);
            }
            LOG.debug("Applied pushdownBinaryOperationOnDiag.");
        }
    }
    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) ArrayList(java.util.ArrayList)

Example 80 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifyEmptyReorgOperation.

/**
 * Replaces a reorg operation over an input known to be empty ({@code HopRewriteUtils.isEmpty})
 * with a zero-valued data-generation op.
 * <ul>
 * <li>TRANSPOSE: zeros sized from the input via {@code createDataGenOp(input, true, input, true, 0)}
 *     (NOTE(review): the flags presumably select transposed dims — confirm).</li>
 * <li>REV: zeros shaped like the input.</li>
 * <li>DIAG (input dims known): column-vector input -> square zero matrix;
 *     otherwise -> zero column vector with one row per input row.</li>
 * <li>RESHAPE: zeros using the reshape's own row/col inputs.</li>
 * </ul>
 * Other reorg ops, and DIAG with unknown input dims, are left unchanged.
 *
 * @param parent the hop that consumes {@code hi}
 * @param hi     the candidate hop
 * @param pos    the position of {@code hi} in {@code parent}'s inputs
 * @return the replacement hop, or {@code hi} if no rule applied
 */
private static Hop simplifyEmptyReorgOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof ReorgOp) {
        ReorgOp rhi = (ReorgOp) hi;
        Hop input = rhi.getInput().get(0);
        // empty input
        if (HopRewriteUtils.isEmpty(input)) {
            // reorg-operation-specific rewrite
            Hop hnew = null;
            if (rhi.getOp() == ReOrgOp.TRANSPOSE) {
                hnew = HopRewriteUtils.createDataGenOp(input, true, input, true, 0);
            } else if (rhi.getOp() == ReOrgOp.REV) {
                hnew = HopRewriteUtils.createDataGenOp(input, 0);
            } else if (rhi.getOp() == ReOrgOp.DIAG) {
                if (HopRewriteUtils.isDimsKnown(input)) {
                    if (input.getDim2() == 1) {
                        // diagv2m
                        hnew = HopRewriteUtils.createDataGenOp(input, false, input, true, 0);
                    } else {
                        // diagm2v
                        hnew = HopRewriteUtils.createDataGenOpByVal(HopRewriteUtils.createValueHop(input, true), new LiteralOp(1), 0);
                    }
                }
            } else if (rhi.getOp() == ReOrgOp.RESHAPE) {
                hnew = HopRewriteUtils.createDataGenOpByVal(rhi.getInput().get(1), rhi.getInput().get(2), 0);
            }
            // modify dag if one of the above rules applied
            if (hnew != null) {
                HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
                // detach the replaced reorg op from its inputs if it is now unreferenced,
                // matching the cleanup done by the other rewrites in this class
                HopRewriteUtils.cleanupUnreferenced(hi);
                hi = hnew;
                LOG.debug("Applied simplifyEmptyReorgOperation");
            }
        }
    }
    return hi;
}
Also used : ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Aggregations

Hop (org.apache.sysml.hops.Hop)307 LiteralOp (org.apache.sysml.hops.LiteralOp)94 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)65 BinaryOp (org.apache.sysml.hops.BinaryOp)63 ArrayList (java.util.ArrayList)61 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)61 HashMap (java.util.HashMap)44 DataOp (org.apache.sysml.hops.DataOp)41 UnaryOp (org.apache.sysml.hops.UnaryOp)41 HashSet (java.util.HashSet)39 ReorgOp (org.apache.sysml.hops.ReorgOp)32 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)28 StatementBlock (org.apache.sysml.parser.StatementBlock)28 IndexingOp (org.apache.sysml.hops.IndexingOp)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)23 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)23 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 DataGenOp (org.apache.sysml.hops.DataGenOp)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HopsException (org.apache.sysml.hops.HopsException)18