use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyDotProductSum.
/**
 * Rewrites a full sum over a column-vector expression into a dot product:
 * sum(v^2) -> cast-as-scalar(t(v) %*% v) and sum(v1*v2) -> cast-as-scalar(t(v1) %*% v2).
 *
 * NOTE: dot-product-sum could be also applied to general sum(a*b). However, we
 * restrict ourselves to column-vector inputs. The sum(a^2) form (possibly rewritten
 * from sum(a*a)) is only rewritten if the power has no other consumer than the sum.
 * The sum(v1*v2) form is only rewritten if neither factor is itself a product,
 * because sum(v1*v2*v3) is later compiled into a ta+* lop. A general mm a%*%b on MR
 * can also be counter-productive (e.g., MMCJ), while tsmm is always beneficial.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the inputs of parent
 * @return the new cast-as-scalar operator if the rewrite was applied, otherwise hi
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // Only full (RowCol) sum aggregates over a single-column input qualify;
    // the vector shape is required for correctness of the dot-product form.
    if (!(hi instanceof AggUnaryOp)
            || ((AggUnaryOp) hi).getOp() != AggOp.SUM
            || ((AggUnaryOp) hi).getDirection() != Direction.RowCol
            || hi.getInput().get(0).getDim2() != 1) {
        return hi;
    }

    Hop hi2 = hi.getInput().get(0);
    Hop baLeft = null;
    Hop baRight = null;

    if (isPowWithLiteralTwo(hi2) && hi2.getParent().size() == 1) {
        // sum(v^2), might have been rewritten from sum(v*v); no other consumer than sum
        baLeft = hi2.getInput().get(0);
        baRight = baLeft;
    } else if (HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) // no other consumer than sum
            && hi2.getInput().get(0).getDim2() == 1
            && hi2.getInput().get(1).getDim2() == 1
            // prevent rewriting sum(v1*v2*v3), which is later compiled into a ta+* lop
            && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
            && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
            // with sum-product rewrites enabled, leave (A^2)*B and B*(A^2) to tak+*;
            // the exponent test is the same one used for sum(v^2) above
            && (!ALLOW_SUM_PRODUCT_REWRITES
                || (!isPowWithLiteralTwo(hi2.getInput().get(0))
                    && !isPowWithLiteralTwo(hi2.getInput().get(1))))) {
        baLeft = hi2.getInput().get(0);
        baRight = hi2.getInput().get(1);
    }

    if (baLeft == null || baRight == null) {
        return hi;
    }

    // create new operator chain t(baLeft) %*% baRight, cast to scalar
    ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
    AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
    UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);

    // rehang new subdag under parent node and detach the replaced sum and its input
    HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, hi2);
    LOG.debug("Applied simplifyDotProductSum.");
    return cast;
}

/**
 * Checks whether h is a binary power whose exponent is the literal 2 (i.e., X^2),
 * comparing the exponent as a double so that non-integer exponents are not truncated.
 */
private static boolean isPowWithLiteralTwo(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.POW)
        && h.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) h.getInput().get(1)) == 2;
}
use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyNrowNcolComputation.
/**
 * Replaces nrow(X) / ncol(X) with a literal whenever the respective dimension of X
 * is known at compile time. The original note suggests this applies even if X is
 * otherwise not required (e.g., when part of a fused operator).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the inputs of parent
 * @return the new literal if the rewrite was applied, otherwise hi
 */
private static Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp)) {
        return hi;
    }

    final OpOp1 op = ((UnaryOp) hi).getOp();
    final Hop dimLiteral;
    final String logPrefix;
    if (op == OpOp1.NROW && hi.getInput().get(0).rowsKnown()) {
        dimLiteral = new LiteralOp(hi.getInput().get(0).getDim1());
        logPrefix = "Applied simplifyNrowComputation nrow(";
    } else if (op == OpOp1.NCOL && hi.getInput().get(0).colsKnown()) {
        dimLiteral = new LiteralOp(hi.getInput().get(0).getDim2());
        logPrefix = "Applied simplifyNcolComputation ncol(";
    } else {
        // neither nrow nor ncol, or the dimension is unknown
        return hi;
    }

    HopRewriteUtils.replaceChildReference(parent, hi, dimLiteral, pos, false);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug(logPrefix + hi.getHopID() + ") -> " + dimLiteral.getName() + " (line " + hi.getBeginLine() + ").");
    return dimLiteral;
}
use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyUnaryPPredOperation.
/**
 * Removes a unary abs/sign/ceil/floor/round that is applied directly to a matrix
 * ppred operation. The unary is treated as an identity on the ppred output
 * (presumably because ppred produces only 0/1 values; confirm in BinaryOp).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the inputs of parent
 * @return the ppred input if the unary was removed, otherwise hi
 */
private static Hop simplifyUnaryPPredOperation(Hop parent, Hop hi, int pos) {
    // only matrix-typed unary operations qualify
    if (!(hi instanceof UnaryOp) || hi.getDataType() != DataType.MATRIX) {
        return hi;
    }

    // the direct input must be a binary ppred operation
    Hop pred = hi.getInput().get(0);
    if (!(pred instanceof BinaryOp) || !((BinaryOp) pred).isPPredOperation()) {
        return hi;
    }

    OpOp1 op = ((UnaryOp) hi).getOp();
    boolean removable = op == OpOp1.ABS
        || op == OpOp1.SIGN
        || op == OpOp1.CEIL
        || op == OpOp1.FLOOR
        || op == OpOp1.ROUND;
    if (!removable) {
        return hi;
    }

    // clear link unary-binary: hang the ppred directly under parent
    HopRewriteUtils.replaceChildReference(parent, hi, pred, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied simplifyUnaryPPredOperation.");
    return pred;
}
use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyBinaryToUnaryOperation.
/**
 * Simplifies binary operations into cheaper unary-like forms. This relies on a
 * previous common subexpression elimination, so that identical operands are the
 * same hop. At the same time, this serves as a canonicalization for more complex
 * rewrites.
 *
 * X+X -> X*2, X*X -> X^2 (matrix X; hi is modified in place), (X>0)-(X<0) -> sign(X)
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the inputs of parent
 * @return the new sign operator for the sign pattern, otherwise hi
 */
private static Hop simplifyBinaryToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    // X+X -> X*2 and X*X -> X^2: the operator is changed in place and the second
    // operand becomes the literal 2; specific lops for X*2 and X^2 are compiled later
    if (lhs == rhs && lhs.getDataType() == DataType.MATRIX) {
        if (bop.getOp() == OpOp2.PLUS) {
            bop.setOp(OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(hi, rhs, new LiteralOp(2), 1);
            LOG.debug("Applied simplifyBinaryToUnaryOperation1 (line " + hi.getBeginLine() + ").");
        } else if (bop.getOp() == OpOp2.MULT) {
            bop.setOp(OpOp2.POW);
            HopRewriteUtils.replaceChildReference(hi, rhs, new LiteralOp(2), 1);
            LOG.debug("Applied simplifyBinaryToUnaryOperation2 (line " + hi.getBeginLine() + ").");
        }
        return hi;
    }

    // (X>0)-(X<0) -> sign(X): both comparisons share the same X and compare against literal 0
    if (bop.getOp() != OpOp2.MINUS
            || !HopRewriteUtils.isBinary(lhs, OpOp2.GREATER)
            || !HopRewriteUtils.isBinary(rhs, OpOp2.LESS)) {
        return hi;
    }
    Hop x = lhs.getInput().get(0);
    Hop gtOperand = lhs.getInput().get(1);
    Hop ltOperand = rhs.getInput().get(1);
    if (x != rhs.getInput().get(0)
            || !(gtOperand instanceof LiteralOp)
            || HopRewriteUtils.getDoubleValue((LiteralOp) gtOperand) != 0
            || !(ltOperand instanceof LiteralOp)
            || HopRewriteUtils.getDoubleValue((LiteralOp) ltOperand) != 0) {
        return hi;
    }

    UnaryOp sign = HopRewriteUtils.createUnary(x, OpOp1.SIGN);
    HopRewriteUtils.replaceChildReference(parent, hi, sign, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, lhs, rhs);
    LOG.debug("Applied simplifyBinaryToUnaryOperation3 (line " + sign.getBeginLine() + ").");
    return sign;
}
use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method fuseBinarySubDAGToUnaryOperation.
/**
 * Handles simplification of more complex sub DAGs into a single, cheaper operation.
 *
 * X*(1-X) -> sprop(X)
 * (1-X)*X -> sprop(X)
 * 1/(1+exp(-X)) -> sigmoid(X)
 * 1/(1+exp(X)) -> sigmoid(-X)
 * (X>0)*X -> max(X,0)
 * X*(X>0) -> max(X,0)
 *
 * After a rewrite, the replaced operator and the discarded intermediates are handed
 * to cleanupUnreferenced. The reused input X is never passed, so its link to the new
 * operator stays intact. Intermediates with other consumers are presumably left in
 * place by cleanupUnreferenced; verify in HopRewriteUtils.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi within the inputs of parent
 * @return the new operator if a rewrite was applied, otherwise hi
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;

        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (1-X)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // left is the discarded (1-X) intermediate; X (right) is reused
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            // X*(1-X)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // right is the discarded (1-X) intermediate; X (left) is reused by sprop
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }

        // sigmoid operator
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    // the (0-X) intermediate of pattern 1, which is discarded (not reused)
                    Hop negation = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: 1/(1 + exp(-X)), with -X matched as 0-X
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0) {
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                            negation = bop3;
                        }
                    } else {
                        // Pattern 2: 1/(1 + exp(X)), e.g., where -(-X) has been removed by
                        // the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        if (negation != null) {
                            HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop, negation);
                        } else {
                            // pattern 2 reuses uopin as input of the new minus, so it is not cleaned
                            HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                        }
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }

        // select positive (selp) pattern, rewritten to max(X,0) (same initial pattern as sprop);
        // X*(X>0) is replaced by max(X,0) due to lower memory requirements and simpler sparsity propagation
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (X>0)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // left is the discarded (X>0) intermediate; X (right) is reused
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            // X*(X>0)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // right is the discarded (X>0) intermediate; X (left) is reused by max
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
Aggregations