Usage of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml (Apache).
From class RewriteAlgebraicSimplificationDynamic, method simplifyColSumsMVMult.
/**
 * Rewrites {@code colSums(X * y)} into {@code t(y) %*% X}, where {@code X} has more
 * than one row and column and {@code y} is a column vector with more than one row.
 * The introduced transpose is expected to be removed by another rewrite when it is
 * unnecessary, i.e., if y == t(Z).
 *
 * @param parent parent high-level operator
 * @param hi candidate high-level operator (the column aggregate)
 * @param pos position of {@code hi} in the input list of {@code parent}
 * @return the new matrix multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyColSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp colAgg = (AggUnaryOp) hi;
    Hop product = colAgg.getInput().get(0);

    // colSums over an elementwise multiply b(*)
    boolean isColSumsOfMult = colAgg.getOp() == AggOp.SUM
        && colAgg.getDirection() == Direction.Col
        && HopRewriteUtils.isBinary(product, OpOp2.MULT);
    if (!isColSumsOfMult) {
        return hi;
    }

    Hop matrix = product.getInput().get(0);
    Hop vector = product.getInput().get(1);
    // matrix-vector case with a column vector on the right
    boolean isMatrixTimesColVector = matrix.getDim1() > 1 && matrix.getDim2() > 1
        && vector.getDim1() > 1 && vector.getDim2() == 1;
    if (!isMatrixTimesColVector) {
        return hi;
    }

    // build t(y) %*% X and hang it under the parent in place of colSums(X*y)
    ReorgOp transposedVector = HopRewriteUtils.createTranspose(vector);
    AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(transposedVector, matrix);
    HopRewriteUtils.replaceChildReference(parent, hi, matMult, pos);
    HopRewriteUtils.cleanupUnreferenced(colAgg, product);
    LOG.debug("Applied simplifyColSumsMVMult");
    return matMult;
}
Usage of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml (Apache).
From class RewriteAlgebraicSimplificationDynamic, method simplifyDotProductSum.
/**
 * Rewrites a full sum over a column-vector product into a vector dot product, i.e.,
 * {@code sum(v^2) -> as.scalar(t(v) %*% v)} and
 * {@code sum(v1*v2) -> as.scalar(t(v1) %*% v2)}, so that the result is computed
 * w/o materialization of the elementwise intermediate.
 *
 * NOTE: a general mm a%*%b on MR can be counter-productive (e.g., MMCJ), while tsmm
 * (the sum(a^2), and transitively sum(a*a), case) is always beneficial. The sum(v1*v2)
 * case is therefore only rewritten when both factors are column vectors, neither factor
 * is itself a product (sum(v1*v2*v3) is later compiled into a ta+* lop) and, if
 * ALLOW_SUM_PRODUCT_REWRITES is set, neither factor is a square (left for tak+*).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of {@code hi} in the input list of {@code parent}
 * @return the new as.scalar operator if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // sum, full aggregate, over a vector (for correctness)
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && hi.getInput().get(0).getDim2() == 1) {
        Hop baLeft = null;
        Hop baRight = null;
        Hop hi2 = hi.getInput().get(0);

        if (isPowByLiteralTwo(hi2) && hi2.getParent().size() == 1) {
            // sum(v^2) w/o other consumers of v^2 (might have been rewritten from sum(v*v))
            Hop input = hi2.getInput().get(0);
            baLeft = input;
            baRight = input;
        } else if (HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1)) {
            // sum(v1*v2) w/o other consumers of v1*v2
            Hop in1 = hi2.getInput().get(0);
            Hop in2 = hi2.getInput().get(1);
            boolean bothColVectors = in1.getDim2() == 1 && in2.getDim2() == 1;
            // prevent rewriting sum(v1*v2*v3), which is later compiled into a ta+* lop
            boolean noNestedMult = !HopRewriteUtils.isBinary(in1, OpOp2.MULT)
                && !HopRewriteUtils.isBinary(in2, OpOp2.MULT);
            // do not rewrite (A^2)*B or B*(A^2); let tak+* handle it
            boolean noSquaredFactor = !ALLOW_SUM_PRODUCT_REWRITES
                || (!isPowByLiteralTwo(in1) && !isPowByLiteralTwo(in2));
            if (bothColVectors && noNestedMult && noSquaredFactor) {
                baLeft = in1;
                baRight = in2;
            }
        }

        // perform actual rewrite (if necessary)
        if (baLeft != null && baRight != null) {
            // create new operator chain: as.scalar(t(left) %*% right)
            ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
            AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
            UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);
            // rehang new subdag under parent node
            HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, hi2);
            hi = cast;
            LOG.debug("Applied simplifyDotProductSum.");
        }
    }
    return hi;
}

/**
 * Checks whether {@code h} is a binary power {@code X^e} whose exponent {@code e} is
 * a literal equal to 2. The literal is compared by its exact double value. A long
 * conversion of the literal could map a non-integer exponent to 2 and misclassify
 * the operation as a square.
 */
private static boolean isPowByLiteralTwo(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.POW)
        && h.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) h.getInput().get(1)) == 2;
}
Usage of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml (Apache).
From class RewriteAlgebraicSimplificationDynamic, method simplifyDiagMatrixMult.
/**
 * Rewrites the diagonal of a matrix product, {@code diag(X %*% Y)} (matrix-to-vector
 * diag, i.e., a result with a single column), into {@code rowSums(X * t(Y))}. This
 * avoids computing the full product when only its diagonal is needed.
 *
 * @param parent parent high-level operator
 * @param hi candidate high-level operator (the diag reorg)
 * @param pos position of {@code hi} in the input list of {@code parent}
 * @return the new row-sum aggregate if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
    boolean isDiagM2V = hi instanceof ReorgOp
        && ((ReorgOp) hi).getOp() == ReOrgOp.DIAG
        && hi.getDim2() == 1;
    if (!isDiagM2V) {
        return hi;
    }

    Hop product = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(product)) {
        return hi;
    }

    Hop lhs = product.getInput().get(0);
    Hop rhs = product.getInput().get(1);
    // X * t(Y), summed per row, yields the diagonal of X %*% Y
    BinaryOp elementwise =
        HopRewriteUtils.createBinary(lhs, HopRewriteUtils.createTranspose(rhs), OpOp2.MULT);
    AggUnaryOp rowSums = HopRewriteUtils.createAggUnaryOp(elementwise, AggOp.SUM, Direction.Row);

    HopRewriteUtils.replaceChildReference(parent, hi, rowSums, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);
    LOG.debug("Applied simplifyDiagMatrixMult");
    return rowSums;
}
Usage of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml (Apache).
From class RewriteAlgebraicSimplificationStatic, method pushdownSumBinaryMult.
/**
 * Pulls a scalar factor out of a full sum: {@code sum(lambda*X) -> lambda*sum(X)}
 * (either operand order). The rewrite requires that the product has at most one
 * consumer, namely this sum.
 *
 * @param parent parent high-level operator
 * @param hi candidate high-level operator (the full sum)
 * @param pos position of {@code hi} in the input list of {@code parent}
 * @return the new scalar multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getDirection() != Direction.RowCol || agg.getOp() != AggOp.SUM) {
        return hi;
    }

    Hop mult = agg.getInput().get(0);
    // only one parent, which is the sum
    if (!HopRewriteUtils.isBinary(mult, OpOp2.MULT, 1)) {
        return hi;
    }

    Hop operand1 = mult.getInput().get(0);
    Hop operand2 = mult.getInput().get(1);
    DataType dt1 = operand1.getDataType();
    DataType dt2 = operand2.getDataType();
    boolean scalarTimesMatrix = (dt1 == DataType.SCALAR && dt2 == DataType.MATRIX)
        || (dt1 == DataType.MATRIX && dt2 == DataType.SCALAR);
    if (!scalarTimesMatrix) {
        return hi;
    }

    // check which operand is the scalar and which is the matrix
    Hop lambda = (dt1 == DataType.SCALAR) ? operand1 : operand2;
    Hop matrix = (dt1 == DataType.MATRIX) ? operand1 : operand2;

    AggUnaryOp sum = HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol);
    Hop bop = HopRewriteUtils.createBinary(lambda, sum, OpOp2.MULT);
    HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
    // detach the replaced sum and product (as the other rewrites do), so that lambda and X
    // do not keep the dead product as an extra parent
    HopRewriteUtils.cleanupUnreferenced(hi, mult);
    LOG.debug("Applied pushdownSumBinaryMult.");
    return bop;
}
Usage of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml (Apache).
From class RewriteAlgebraicSimplificationStatic, method pushdownUnaryAggTransposeOperation.
/**
 * Pushes a row/column aggregate below a transpose by flipping its direction, e.g.,
 * {@code rowSums(t(X)) -> t(colSums(X))} and {@code colSums(t(X)) -> t(rowSums(X))}.
 * The aggregate and the inner transpose must each have exactly one parent, and the
 * aggregate op must be listed in LOOKUP_VALID_ROW_COL_AGGREGATE.
 *
 * @param parent parent high-level operator
 * @param hi candidate high-level operator (the row/column aggregate)
 * @param pos position of {@code hi} in the input list of {@code parent}
 * @return the new outer transpose if the rewrite applied, otherwise {@code hi}
 */
private static Hop pushdownUnaryAggTransposeOperation(Hop parent, Hop hi, int pos) {
    boolean applicable = hi instanceof AggUnaryOp
        && hi.getParent().size() == 1
        && (((AggUnaryOp) hi).getDirection() == Direction.Row
            || ((AggUnaryOp) hi).getDirection() == Direction.Col)
        && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
        && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE);
    if (!applicable) {
        return hi;
    }

    AggUnaryOp uagg = (AggUnaryOp) hi;
    Hop innerTranspose = uagg.getInput().get(0);
    Hop input = innerTranspose.getInput().get(0);

    // unlink the transpose from X, the aggregate from the transpose, and the parent from the aggregate
    HopRewriteUtils.removeAllChildReferences(innerTranspose);
    HopRewriteUtils.removeAllChildReferences(hi);
    HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);

    // flip direction; the guard above ensures it is either Row or Col
    if (uagg.getDirection() == Direction.Row) {
        uagg.setDirection(Direction.Col);
        LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line " + hi.getBeginLine() + ").");
    } else {
        uagg.setDirection(Direction.Row);
        LOG.debug("Applied pushdownUnaryAggTransposeOperation2 (line " + hi.getBeginLine() + ").");
    }

    // aggregate X directly, then transpose the (flipped) aggregate under the parent
    HopRewriteUtils.addChildReference(uagg, input);
    uagg.refreshSizeInformation();
    Hop outerTranspose = HopRewriteUtils.createTranspose(uagg);
    HopRewriteUtils.addChildReference(parent, outerTranspose, pos);
    return outerTranspose;
}
Aggregations