Search in sources:

Example 26 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method simplifyEmptyAggregate.

/**
 * Replaces an aggregate over an input that {@code HopRewriteUtils.isEmpty} reports as
 * empty with a zero-valued replacement of the aggregate's output shape.
 *
 * <p>Only aggregates listed in {@code LOOKUP_VALID_EMPTY_AGGREGATE} are rewritten, and
 * only when the input has at least one row and one column; aggregates over inputs with
 * zero rows/cols (or unknown dims) are left untouched. A full aggregate becomes the
 * literal {@code 0.0}; column/row aggregates become a data-generation op with value 0
 * whose shape is taken from the aggregate and its input.
 *
 * @param parent the parent HOP that consumes {@code hi}
 * @param hi the candidate HOP
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the replacement HOP if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggUnaryOp) {
        AggUnaryOp uhi = (AggUnaryOp) hi;
        Hop input = uhi.getInput().get(0);
        // check for valid empty aggregates, except for matrices with zero rows/cols
        if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_EMPTY_AGGREGATE) && HopRewriteUtils.isEmpty(input) && input.getDim1() >= 1 && input.getDim2() >= 1) {
            Hop hnew = null;
            if (uhi.getDirection() == Direction.RowCol)
                hnew = new LiteralOp(0.0);
            else if (uhi.getDirection() == Direction.Col)
                // nrow(uhi)=1
                hnew = HopRewriteUtils.createDataGenOp(uhi, input, 0);
            else
                // if( uhi.getDirection() == Direction.Row )
                // ncol(uhi)=1
                hnew = HopRewriteUtils.createDataGenOp(input, uhi, 0);
            // add new child to parent input
            HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
            // detach the replaced aggregate (and its input, if it is now unreferenced) so
            // stale parent links do not block later rewrites that check parent counts;
            // consistent with the other rewrites in this class
            HopRewriteUtils.cleanupUnreferenced(hi, input);
            hi = hnew;
            LOG.debug("Applied simplifyEmptyAggregate");
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 27 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method pushdownSumOnAdditiveBinary.

/**
 * Pushes a full sum down through an element-wise addition or subtraction.
 *
 * <p>Patterns: {@code sum(A+B) -> sum(A)+sum(B)} and {@code sum(A-B) -> sum(A)-sum(B)}.
 * The rewrite fires only when the sum is a full (RowCol) SUM, the binary operation is
 * consumed solely by that sum, and both operands are matrices of equal size.
 *
 * @param parent the parent high-level operator
 * @param hi high-level operator (candidate full-sum root)
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new binary operator if rewritten, otherwise {@code hi}
 */
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) {
    // guard: the root must be a full SUM aggregate
    if (!(hi instanceof AggUnaryOp))
        return hi;
    AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getDirection() != Direction.RowCol || agg.getOp() != AggOp.SUM)
        return hi;

    // guard: its input must be a binary op with this sum as the single parent
    Hop child = hi.getInput().get(0);
    if (!(child instanceof BinaryOp) || child.getParent().size() != 1)
        return hi;
    BinaryOp bop = (BinaryOp) child;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);

    // guard: dims(A) == dims(B) and both operands are matrices
    if (!HopRewriteUtils.isEqualSize(lhs, rhs)
            || lhs.getDataType() != DataType.MATRIX
            || rhs.getDataType() != DataType.MATRIX)
        return hi;

    // guard: only PLUS and MINUS distribute over a full sum
    OpOp2 op = bop.getOp();
    if (op != OpOp2.PLUS && op != OpOp2.MINUS)
        return hi;

    // build the replacement subdag sum(A) op sum(B)
    AggUnaryOp sumLeft = HopRewriteUtils.createSum(lhs);
    AggUnaryOp sumRight = HopRewriteUtils.createSum(rhs);
    BinaryOp combined = HopRewriteUtils.createBinary(sumLeft, sumRight, op);

    // rewire the parent, then detach the old sum and binary op if unreferenced
    HopRewriteUtils.replaceChildReference(parent, hi, combined, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, bop);
    LOG.debug("Applied pushdownSumOnAdditiveBinary.");
    return combined;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) OpOp2(org.apache.sysml.hops.Hop.OpOp2) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 28 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method simplifyRowwiseAggregate.

/**
 * Simplifies row-wise aggregates whose input shape makes the aggregation trivial.
 *
 * <ul>
 *   <li>Input with a single column: a row variance is replaced by a zero column
 *       vector (built via {@code createDataGenOp} with value 0); any other valid row
 *       aggregate is replaced by the input itself.</li>
 *   <li>Input with a single row: the aggregate is switched in place to a full (RowCol)
 *       aggregate with scalar output, and a cast-as-matrix is hung between it and all
 *       of its previous parents so they keep receiving a matrix.</li>
 * </ul>
 *
 * @param parent the parent HOP that consumes {@code hi}
 * @param hi the candidate HOP
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the replacement HOP if a rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyRowwiseAggregate(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggUnaryOp) {
        AggUnaryOp uhi = (AggUnaryOp) hi;
        Hop input = uhi.getInput().get(0);
        if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)) {
            if (uhi.getDirection() == Direction.Row) {
                if (input.getDim2() == 1) {
                    if (uhi.getOp() == AggOp.VAR) {
                        // For the row variance aggregation, if the input is a column vector,
                        // the row variances will each be zero.
                        // Therefore, perform a rewrite from ROWVAR(X) to a column vector of
                        // zeros.
                        Hop emptyCol = HopRewriteUtils.createDataGenOp(input, uhi, 0);
                        HopRewriteUtils.replaceChildReference(parent, hi, emptyCol, pos);
                        HopRewriteUtils.cleanupUnreferenced(hi, input);
                        // replace current HOP with new empty column HOP
                        hi = emptyCol;
                        LOG.debug("Applied simplifyRowwiseAggregate for rowVars");
                    } else {
                        // All other valid row aggregations over a column vector will result
                        // in the column vector itself.
                        // Therefore, remove unnecessary row aggregation for 1 col
                        HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
                        HopRewriteUtils.cleanupUnreferenced(hi);
                        hi = input;
                        LOG.debug("Applied simplifyRowwiseAggregate1");
                    }
                } else if (input.getDim1() == 1) {
                    // snapshot the old parents before creating the cast over the aggregate,
                    // so only the original consumers get rehung; the typed copy constructor
                    // replaces the former unchecked (ArrayList<Hop>) clone() cast
                    ArrayList<Hop> parents = new ArrayList<>(hi.getParent());
                    // simplify row-aggregate to full aggregate
                    uhi.setDirection(Direction.RowCol);
                    uhi.setDataType(DataType.SCALAR);
                    // create cast to keep same output datatype
                    UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
                    // rehang cast under all parents
                    for (Hop p : parents) {
                        int ix = HopRewriteUtils.getChildReferencePos(p, hi);
                        HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
                    }
                    hi = cast;
                    LOG.debug("Applied simplifyRowwiseAggregate2");
                }
            }
        }
    }
    return hi;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList)

Example 29 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method simplifySumDiagToTrace.

/**
 * Rewrites {@code sum(diag(X))} into {@code trace(X)} when the diag produces a single
 * column (matrix-to-vector diag).
 *
 * <p>The rewrite is done in place: the diag operator is bypassed so the sum reads
 * {@code X} directly, the diag is detached if no longer referenced, and the aggregate
 * operation is switched from SUM to TRACE.
 *
 * @param hi the candidate HOP
 * @return {@code hi} itself, possibly modified in place
 */
private static Hop simplifySumDiagToTrace(Hop hi) {
    if (!(hi instanceof AggUnaryOp))
        return hi;
    AggUnaryOp agg = (AggUnaryOp) hi;

    // only a full sum qualifies
    boolean isFullSum = agg.getOp() == AggOp.SUM && agg.getDirection() == Direction.RowCol;
    if (!isFullSum)
        return hi;

    // diagM2V: a DIAG reorg whose output has exactly one column
    Hop diag = agg.getInput().get(0);
    boolean isDiagM2V = diag instanceof ReorgOp
            && ((ReorgOp) diag).getOp() == ReOrgOp.DIAG
            && diag.getDim2() == 1;
    if (isDiagM2V) {
        Hop matrix = diag.getInput().get(0);
        // bypass the diag operator and turn the sum into a trace
        HopRewriteUtils.replaceChildReference(agg, diag, matrix, 0);
        HopRewriteUtils.cleanupUnreferenced(diag);
        agg.setOp(AggOp.TRACE);
        LOG.debug("Applied simplifySumDiagToTrace");
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 30 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method fuseSumSquared.

/**
 * Fuses {@code SUM(X^2)} into a single {@code SUM_SQ(X)} aggregate.
 *
 * <p>Applies only when the exponent is the literal 2, the power operation has the
 * sum as its only consumer, and X has more than one column. The new aggregate keeps
 * the direction of the original sum.
 *
 * @param parent Parent HOP for which hi is an input.
 * @param hi Current HOP for potential rewrite.
 * @param pos Position of hi in parent's list of inputs.
 *
 * @return Either hi or the rewritten HOP replacing it.
 */
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
    // guard: hi must be a SUM aggregate
    boolean isSum = hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.SUM;
    if (!isSum)
        return hi;
    AggUnaryOp sum = (AggUnaryOp) hi;

    // guard: input must be POW(X, 2) with a literal exponent and no other consumers
    Hop pow = sum.getInput().get(0);
    if (!HopRewriteUtils.isBinary(pow, OpOp2.POW))
        return hi;
    Hop exponent = pow.getInput().get(1);
    if (!(exponent instanceof LiteralOp)
            || HopRewriteUtils.getDoubleValue((LiteralOp) exponent) != 2
            || pow.getParent().size() != 1)
        return hi;

    // guard: only rewrite when X has more than one column (i.e. is not a column vector)
    Hop x = pow.getInput().get(0);
    if (x.getDim2() <= 1)
        return hi;

    // SUM(POW(X,2)) -> SUM_SQ(X), preserving the aggregation direction
    Direction dir = sum.getDirection();
    AggUnaryOp sumSq = HopRewriteUtils.createAggUnaryOp(x, AggOp.SUM_SQ, dir);
    HopRewriteUtils.replaceChildReference(parent, hi, sumSq, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, pow);
    LOG.debug("Applied fuseSumSquared.");
    return sumSq;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) Direction(org.apache.sysml.hops.Hop.Direction)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)36 Hop (org.apache.sysml.hops.Hop)33 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)19 LiteralOp (org.apache.sysml.hops.LiteralOp)14 BinaryOp (org.apache.sysml.hops.BinaryOp)12 ReorgOp (org.apache.sysml.hops.ReorgOp)11 UnaryOp (org.apache.sysml.hops.UnaryOp)11 IndexingOp (org.apache.sysml.hops.IndexingOp)7 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)6 ArrayList (java.util.ArrayList)5 TernaryOp (org.apache.sysml.hops.TernaryOp)5 DataOp (org.apache.sysml.hops.DataOp)4 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)4 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)3 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)2 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)2 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)2