Use of org.apache.sysml.hops.Hop.Direction in project incubator-systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method fuseSumSquared:
/**
 * Fuses a sum of squares, {@code SUM(X^2)}, into a single {@code SUM_SQ(X)}
 * aggregate HOP that keeps the aggregation direction of the original SUM.
 * <p>
 * The rewrite is applied only when all of the following hold:
 * {@code hi} is a SUM aggregate; its first input is {@code POW(X, e)} where
 * {@code e} is a literal equal to 2; that POW has exactly one consumer; and
 * X reports more than one column.
 *
 * @param parent parent HOP that has {@code hi} among its inputs
 * @param hi     HOP currently being considered for the rewrite
 * @param pos    index of {@code hi} within {@code parent}'s inputs
 * @return the new SUM_SQ HOP if the rewrite was applied, otherwise {@code hi} unchanged
 */
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
	// guard: only SUM aggregates are candidates
	if( !(hi instanceof AggUnaryOp) || ((AggUnaryOp) hi).getOp() != AggOp.SUM )
		return hi;

	// guard: the aggregated value must be a power operation
	Hop powOp = hi.getInput().get(0);
	if( !HopRewriteUtils.isBinary(powOp, OpOp2.POW) )
		return hi;

	// guard: exponent is the literal 2 and nobody else consumes POW(X,2),
	// so the intermediate can be dropped once SUM_SQ replaces it
	Hop exponent = powOp.getInput().get(1);
	boolean exclusiveSquare = exponent instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValue((LiteralOp) exponent) == 2
		&& powOp.getParent().size() == 1;
	if( !exclusiveSquare )
		return hi;

	// guard: X must have more than one column; a column vector is left as is
	// (NOTE(review): this presumably also skips inputs whose column count is
	// unknown, if unknown is encoded as a non-positive value -- confirm)
	Hop base = powOp.getInput().get(0);
	if( base.getDim2() <= 1 )
		return hi;

	// rewrite SUM(POW(X,2)) -> SUM_SQ(X) with the same aggregation direction
	Direction direction = ((AggUnaryOp) hi).getDirection();
	AggUnaryOp fused = HopRewriteUtils.createAggUnaryOp(base, AggOp.SUM_SQ, direction);
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, powOp);
	LOG.debug("Applied fuseSumSquared.");
	return fused;
}
Use of org.apache.sysml.hops.Hop.Direction in project systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method fuseSumSquared:
/**
 * Fuses a sum of squares, {@code SUM(X^2)}, into a single {@code SUM_SQ(X)}
 * aggregate HOP that keeps the aggregation direction of the original SUM.
 * <p>
 * The rewrite is applied only when all of the following hold:
 * {@code hi} is a SUM aggregate; its first input is {@code POW(X, e)} where
 * {@code e} is a literal equal to 2; that POW has exactly one consumer; and
 * X reports more than one column.
 *
 * @param parent parent HOP that has {@code hi} among its inputs
 * @param hi     HOP currently being considered for the rewrite
 * @param pos    index of {@code hi} within {@code parent}'s inputs
 * @return the new SUM_SQ HOP if the rewrite was applied, otherwise {@code hi} unchanged
 */
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
	// guard: only SUM aggregates are candidates
	if( !(hi instanceof AggUnaryOp) || ((AggUnaryOp) hi).getOp() != AggOp.SUM )
		return hi;

	// guard: the aggregated value must be a power operation
	Hop powOp = hi.getInput().get(0);
	if( !HopRewriteUtils.isBinary(powOp, OpOp2.POW) )
		return hi;

	// guard: exponent is the literal 2 and nobody else consumes POW(X,2),
	// so the intermediate can be dropped once SUM_SQ replaces it
	Hop exponent = powOp.getInput().get(1);
	boolean exclusiveSquare = exponent instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValue((LiteralOp) exponent) == 2
		&& powOp.getParent().size() == 1;
	if( !exclusiveSquare )
		return hi;

	// guard: X must have more than one column; a column vector is left as is
	// (NOTE(review): this presumably also skips inputs whose column count is
	// unknown, if unknown is encoded as a non-positive value -- confirm)
	Hop base = powOp.getInput().get(0);
	if( base.getDim2() <= 1 )
		return hi;

	// rewrite SUM(POW(X,2)) -> SUM_SQ(X) with the same aggregation direction
	Direction direction = ((AggUnaryOp) hi).getDirection();
	AggUnaryOp fused = HopRewriteUtils.createAggUnaryOp(base, AggOp.SUM_SQ, direction);
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, powOp);
	LOG.debug("Applied fuseSumSquared.");
	return fused;
}
Aggregations