use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyEmptyAggregate.
/**
 * Replaces an aggregate over an empty input by a constant-zero result.
 *
 * The rewrite only fires for aggregate operations contained in
 * LOOKUP_VALID_EMPTY_AGGREGATE whose input is reported as empty by
 * HopRewriteUtils.isEmpty and has at least one row and one column.
 *
 * @param parent the parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the parent's inputs
 * @return the replacement operator if rewritten, otherwise hi
 */
private static Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop in = agg.getInput().get(0);

    // valid empty aggregates only; matrices with zero rows/cols are excluded
    boolean applicable = HopRewriteUtils.isValidOp(agg.getOp(), LOOKUP_VALID_EMPTY_AGGREGATE)
        && HopRewriteUtils.isEmpty(in)
        && in.getDim1() >= 1
        && in.getDim2() >= 1;
    if (!applicable) {
        return hi;
    }

    Hop replacement;
    if (agg.getDirection() == Direction.RowCol) {
        // full aggregate: scalar zero literal
        replacement = new LiteralOp(0.0);
    } else if (agg.getDirection() == Direction.Col) {
        // column aggregate: nrow(agg)=1
        replacement = HopRewriteUtils.createDataGenOp(agg, in, 0);
    } else {
        // remaining case (row aggregate): ncol(agg)=1
        replacement = HopRewriteUtils.createDataGenOp(in, agg, 0);
    }

    // hang the new operator under the parent in place of the aggregate
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    LOG.debug("Applied simplifyEmptyAggregate");
    return replacement;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method pushdownSumOnAdditiveBinary.
/**
 * Pushes a full sum below an additive binary operation:
 * sum(A+B) -&gt; sum(A)+sum(B) and sum(A-B) -&gt; sum(A)-sum(B).
 *
 * The rewrite requires that the binary operation has the sum as its only
 * parent and that both operands are matrices of equal size.
 *
 * @param parent the parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) {
    // root must be a full sum ...
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp root = (AggUnaryOp) hi;
    if (root.getDirection() != Direction.RowCol || root.getOp() != AggOp.SUM) {
        return hi;
    }

    // ... over a binary operation that has no other consumer
    Hop child = hi.getInput().get(0);
    if (!(child instanceof BinaryOp) || child.getParent().size() != 1) {
        return hi;
    }

    BinaryOp bop = (BinaryOp) child;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    // dims(A) == dims(B), both operands matrices
    boolean matrixOperands = HopRewriteUtils.isEqualSize(left, right)
        && left.getDataType() == DataType.MATRIX
        && right.getDataType() == DataType.MATRIX;
    if (!matrixOperands) {
        return hi;
    }

    // only plus (pattern a) and minus (pattern b) are distributed
    OpOp2 op = bop.getOp();
    if (op != OpOp2.PLUS && op != OpOp2.MINUS) {
        return hi;
    }

    // build the new subdag sum(A) op sum(B)
    AggUnaryOp sumLeft = HopRewriteUtils.createSum(left);
    AggUnaryOp sumRight = HopRewriteUtils.createSum(right);
    BinaryOp pushed = HopRewriteUtils.createBinary(sumLeft, sumRight, op);

    // rewire the new subdag and drop the old sum and binary operation
    HopRewriteUtils.replaceChildReference(parent, hi, pushed, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, bop);
    LOG.debug("Applied pushdownSumOnAdditiveBinary.");
    return pushed;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyRowwiseAggregate.
/**
 * Simplifies row-wise aggregates whose input is a single column or a
 * single row. Only operations in LOOKUP_VALID_ROW_COL_AGGREGATE with
 * direction Row are considered.
 *
 * - input with one column, VAR: replaced by a generated zero-filled operator
 * - input with one column, other ops: the aggregate is removed
 * - input with one row: the aggregate is turned into a full aggregate
 *   wrapped in a cast to matrix
 *
 * @param parent the parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the parent's inputs
 * @return the operator that takes the place of hi, or hi if unchanged
 */
@SuppressWarnings("unchecked")
private static Hop simplifyRowwiseAggregate(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop in = agg.getInput().get(0);

    boolean rowAggregate = HopRewriteUtils.isValidOp(agg.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)
        && agg.getDirection() == Direction.Row;
    if (!rowAggregate) {
        return hi;
    }

    if (in.getDim2() == 1) {
        if (agg.getOp() == AggOp.VAR) {
            // Row variances over a column vector are all zero, hence
            // ROWVAR(X) becomes a column vector of zeros.
            Hop zeros = HopRewriteUtils.createDataGenOp(in, agg, 0);
            HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, in);
            LOG.debug("Applied simplifyRowwiseAggregate for rowVars");
            return zeros;
        }

        // Any other valid row aggregate over a column vector yields the
        // vector itself, so the aggregate is dropped.
        HopRewriteUtils.replaceChildReference(parent, hi, in, pos);
        HopRewriteUtils.cleanupUnreferenced(hi);
        LOG.debug("Applied simplifyRowwiseAggregate1");
        return in;
    }

    if (in.getDim1() == 1) {
        // snapshot the current parents first; creating the cast below
        // happens afterwards so the snapshot holds only the old parents
        ArrayList<Hop> oldParents = (ArrayList<Hop>) hi.getParent().clone();

        // turn the row aggregate into a full aggregate (scalar output)
        agg.setDirection(Direction.RowCol);
        agg.setDataType(DataType.SCALAR);

        // cast back to matrix so consumers see the same output datatype
        UnaryOp asMatrix = HopRewriteUtils.createUnary(agg, OpOp1.CAST_AS_MATRIX);

        // rehang the cast under every former parent of the aggregate
        for (Hop oldParent : oldParents) {
            int childPos = HopRewriteUtils.getChildReferencePos(oldParent, hi);
            HopRewriteUtils.replaceChildReference(oldParent, hi, asMatrix, childPos);
        }

        LOG.debug("Applied simplifyRowwiseAggregate2");
        return asMatrix;
    }

    return hi;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifySumDiagToTrace.
/**
 * Rewrites a full sum over a matrix-to-vector diag into a trace:
 * the diag operator is removed from below the aggregate and the
 * aggregate operation is changed in place from SUM to TRACE.
 *
 * @param hi high-level operator
 * @return hi (modified in place when the pattern matches)
 */
private static Hop simplifySumDiagToTrace(Hop hi) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp agg = (AggUnaryOp) hi;
    boolean fullSum = agg.getOp() == AggOp.SUM
        && agg.getDirection() == Direction.RowCol;
    if (!fullSum) {
        return hi;
    }

    Hop diag = agg.getInput().get(0);
    // diagM2V: a DIAG reorg whose output has a single column
    boolean diagToVector = diag instanceof ReorgOp
        && ((ReorgOp) diag).getOp() == ReOrgOp.DIAG
        && diag.getDim2() == 1;
    if (diagToVector) {
        Hop matrix = diag.getInput().get(0);

        // bypass the diag operator
        HopRewriteUtils.replaceChildReference(agg, diag, matrix, 0);
        HopRewriteUtils.cleanupUnreferenced(diag);

        // the sum now computes a trace
        agg.setOp(AggOp.TRACE);
        LOG.debug("Applied simplifySumDiagToTrace");
    }
    return hi;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method fuseSumSquared.
/**
 * Fuses SUM(X^2) into a single SUM_SQ(X) HOP.
 *
 * The rewrite applies when the input of the SUM is POW(X, 2) with a
 * literal exponent, the POW has no other consumers, and X has more
 * than one column. The direction of the original SUM is carried over.
 *
 * @param parent Parent HOP for which hi is an input.
 * @param hi Current HOP for potential rewrite.
 * @param pos Position of hi in parent's list of inputs.
 *
 * @return Either hi or the rewritten HOP replacing it.
 */
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
    // only SUM aggregates are candidates
    if (!(hi instanceof AggUnaryOp) || ((AggUnaryOp) hi).getOp() != AggOp.SUM) {
        return hi;
    }

    Hop pow = hi.getInput().get(0);

    // input must be POW(X,2) and the SUM must be its only consumer
    boolean exclusiveSquare = HopRewriteUtils.isBinary(pow, OpOp2.POW)
        && pow.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) pow.getInput().get(1)) == 2
        && pow.getParent().size() == 1;
    if (!exclusiveSquare) {
        return hi;
    }

    Hop x = pow.getInput().get(0);
    // skip if X is a column vector
    if (x.getDim2() <= 1) {
        return hi;
    }

    // SUM(POW(X,2)) -> SUM_SQ(X), keeping the aggregation direction
    Direction dir = ((AggUnaryOp) hi).getDirection();
    AggUnaryOp fused = HopRewriteUtils.createAggUnaryOp(x, AggOp.SUM_SQ, dir);
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, pow);
    LOG.debug("Applied fuseSumSquared.");
    return fused;
}
Aggregations