Search in sources:

Example 1 with Direction

Use of org.apache.sysml.hops.Hop.Direction in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method fuseSumSquared.

/**
 * Fuses {@code SUM(X^2)} into a single {@code SUM_SQ(X)} HOP.
 * <p>
 * The rewrite is applied only when every one of these conditions holds:
 * <ul>
 *   <li>{@code hi} is an {@link AggUnaryOp} whose operation is SUM;</li>
 *   <li>its first input is a binary POW whose exponent is a {@link LiteralOp} equal to 2;</li>
 *   <li>that POW HOP has exactly one parent, i.e. no consumer other than the SUM;</li>
 *   <li>the POW base X reports more than one column ({@code getDim2() > 1}); column
 *       vectors, and any X whose reported column count is not above 1, are left alone.</li>
 * </ul>
 * The new aggregate keeps the aggregation direction of the original SUM.
 *
 * @param parent Parent HOP for which hi is an input.
 * @param hi Current HOP for potential rewrite.
 * @param pos Position of hi in parent's list of inputs.
 *
 * @return The new SUM_SQ HOP if the rewrite fired, otherwise hi itself.
 */
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
    // Guard: only SUM aggregates are candidates.
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp sum = (AggUnaryOp) hi;
    if (sum.getOp() != AggOp.SUM) {
        return hi;
    }

    // Guard: the aggregated expression must be POW(X, 2) with a literal exponent.
    Hop pow = sum.getInput().get(0);
    if (!HopRewriteUtils.isBinary(pow, OpOp2.POW)) {
        return hi;
    }
    Hop exponent = pow.getInput().get(1);
    if (!(exponent instanceof LiteralOp)
            || HopRewriteUtils.getDoubleValue((LiteralOp) exponent) != 2) {
        return hi;
    }

    // Guard: the POW must have no consumer other than this SUM, since it is
    // handed to cleanupUnreferenced once the fused HOP is wired in.
    if (pow.getParent().size() != 1) {
        return hi;
    }

    // Guard: X must report more than one column.
    Hop base = pow.getInput().get(0);
    if (!(base.getDim2() > 1)) {
        return hi;
    }

    // Rewrite SUM(POW(X,2)) -> SUM_SQ(X), preserving the aggregation direction.
    Direction direction = sum.getDirection();
    AggUnaryOp sumSq = HopRewriteUtils.createAggUnaryOp(base, AggOp.SUM_SQ, direction);
    HopRewriteUtils.replaceChildReference(parent, hi, sumSq, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, pow);
    LOG.debug("Applied fuseSumSquared.");
    return sumSq;
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), Hop (org.apache.sysml.hops.Hop), LiteralOp (org.apache.sysml.hops.LiteralOp), Direction (org.apache.sysml.hops.Hop.Direction)

Example 2 with Direction

Use of org.apache.sysml.hops.Hop.Direction in project systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method fuseSumSquared.

/**
 * Fuses {@code SUM(X^2)} into a single {@code SUM_SQ(X)} HOP.
 * <p>
 * The rewrite is applied only when every one of these conditions holds:
 * <ul>
 *   <li>{@code hi} is an {@link AggUnaryOp} whose operation is SUM;</li>
 *   <li>its first input is a binary POW whose exponent is a {@link LiteralOp} equal to 2;</li>
 *   <li>that POW HOP has exactly one parent, i.e. no consumer other than the SUM;</li>
 *   <li>the POW base X reports more than one column ({@code getDim2() > 1}); column
 *       vectors, and any X whose reported column count is not above 1, are left alone.</li>
 * </ul>
 * The new aggregate keeps the aggregation direction of the original SUM.
 *
 * @param parent Parent HOP for which hi is an input.
 * @param hi Current HOP for potential rewrite.
 * @param pos Position of hi in parent's list of inputs.
 *
 * @return The new SUM_SQ HOP if the rewrite fired, otherwise hi itself.
 */
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
    // Guard: only SUM aggregates are candidates.
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp sum = (AggUnaryOp) hi;
    if (sum.getOp() != AggOp.SUM) {
        return hi;
    }

    // Guard: the aggregated expression must be POW(X, 2) with a literal exponent.
    Hop pow = sum.getInput().get(0);
    if (!HopRewriteUtils.isBinary(pow, OpOp2.POW)) {
        return hi;
    }
    Hop exponent = pow.getInput().get(1);
    if (!(exponent instanceof LiteralOp)
            || HopRewriteUtils.getDoubleValue((LiteralOp) exponent) != 2) {
        return hi;
    }

    // Guard: the POW must have no consumer other than this SUM, since it is
    // handed to cleanupUnreferenced once the fused HOP is wired in.
    if (pow.getParent().size() != 1) {
        return hi;
    }

    // Guard: X must report more than one column.
    Hop base = pow.getInput().get(0);
    if (!(base.getDim2() > 1)) {
        return hi;
    }

    // Rewrite SUM(POW(X,2)) -> SUM_SQ(X), preserving the aggregation direction.
    Direction direction = sum.getDirection();
    AggUnaryOp sumSq = HopRewriteUtils.createAggUnaryOp(base, AggOp.SUM_SQ, direction);
    HopRewriteUtils.replaceChildReference(parent, hi, sumSq, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, pow);
    LOG.debug("Applied fuseSumSquared.");
    return sumSq;
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), Hop (org.apache.sysml.hops.Hop), LiteralOp (org.apache.sysml.hops.LiteralOp), Direction (org.apache.sysml.hops.Hop.Direction)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 2, Hop (org.apache.sysml.hops.Hop): 2, Direction (org.apache.sysml.hops.Hop.Direction): 2, LiteralOp (org.apache.sysml.hops.LiteralOp): 2