use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method removeUnnecessaryVectorizeOperation.
/**
 * Removes an unnecessary vectorization from one input of a binary matrix operation.
 * If one input is a RAND data generator with a constant value, that input is replaced
 * by the generator's RAND_MIN parameter input. Only operations that support
 * matrix-scalar execution are considered.
 *
 * The rewrite is skipped for outer-style inputs (column vector op row vector), and it
 * touches at most one side so the output stays semantically equivalent. A future
 * extension might remove vectors from both sides, compute the binary op on scalars and
 * feed the result into a data generator of the original dimensions.
 *
 * @param hi candidate high-level operator (modified in place if the rewrite applies)
 * @return hi itself
 */
private static Hop removeUnnecessaryVectorizeOperation(Hop hi) {
    // only binary matrix operations that can also run as matrix-scalar operations
    if (!(hi instanceof BinaryOp) || hi.getDataType() != DataType.MATRIX
        || !((BinaryOp) hi).supportsMatrixScalarOperations()) {
        return hi;
    }

    BinaryOp bop = (BinaryOp) hi;
    Hop in1 = bop.getInput().get(0);
    Hop in2 = bop.getInput().get(1);

    // leave outer-style operations (column vector op row vector) untouched
    boolean outer = in1.getDim1() > 1 && in1.getDim2() == 1
        && in2.getDim1() == 1 && in2.getDim2() > 1;
    if (outer) {
        return hi;
    }

    if (in1.getDataType() == DataType.MATRIX && in2 instanceof DataGenOp) {
        // right side: a constant RAND generator can be replaced by its scalar value
        DataGenOp gen = (DataGenOp) in2;
        if (gen.getOp() == DataGenMethod.RAND && gen.hasConstantValue()) {
            Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(bop, gen, scalar, 1);
            HopRewriteUtils.cleanupUnreferenced(gen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation1");
        }
    } else if (in2.getDataType() == DataType.MATRIX && in1 instanceof DataGenOp) {
        // left side: additionally require that the generator's dimensions are covered
        // by the right input, presumably so that dropping it cannot change the output shape
        DataGenOp gen = (DataGenOp) in1;
        if (gen.getOp() == DataGenMethod.RAND && gen.hasConstantValue()) {
            boolean colsCovered = in1.getDim2() == 1 || in2.getDim2() > 1;
            boolean rowsCovered = in1.getDim1() == 1 || in2.getDim1() > 1;
            if (colsCovered && rowsCovered) {
                Hop scalar = gen.getInput().get(gen.getParamIndex(DataExpression.RAND_MIN));
                HopRewriteUtils.replaceChildReference(bop, gen, scalar, 0);
                HopRewriteUtils.cleanupUnreferenced(gen);
                LOG.debug("Applied removeUnnecessaryVectorizeOperation2");
            }
        }
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method fuseBinarySubDAGToUnaryOperation.
/**
 * Handles simplification of more complex sub DAGs into a single, cheaper operation:
 *
 * X*(1-X)       -> sprop(X)
 * (1-X)*X       -> sprop(X)
 * 1/(1+exp(-X)) -> sigmoid(X)
 * 1/(1+exp(X))  -> sigmoid(-X)
 * (X>0)*X       -> max(X,0)
 * X*(X>0)       -> max(X,0)
 *
 * At most one of these rewrites is applied per call. The new operator replaces
 * {@code hi} as input {@code pos} of {@code parent}. The superseded hops (the root
 * binary op and the matched intermediates) are passed to
 * HopRewriteUtils.cleanupUnreferenced; the shared input X is never passed, because
 * it is still consumed by the new operator.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the replacement operator if a rewrite was applied, otherwise hi
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;

        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            if (left instanceof BinaryOp) { // (1-X)*X
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // left is the superseded (1-X) intermediate
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            if (!applied && right instanceof BinaryOp) { // X*(1-X)
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // right is the superseded (1-X) intermediate; left (X) is now consumed
                    // by sprop and must not be cleaned up (previously cleaned 'left' by mistake)
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }

        // sigmoid operator
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    // the (0-X) hop consumed by pattern 1, which becomes superseded as well
                    Hop supersededMinus = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: 1/(1+exp(0-X)) -> sigmoid(X)
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0) {
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                            supersededMinus = bop3;
                        }
                    } else {
                        // Pattern 2: 1/(1+exp(X)), e.g., where -(-X) has been removed by
                        // the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                        // must run after exp has been cleaned up, so the (0-X) hop
                        // can become unreferenced and release its reference on X
                        if (supersededMinus != null)
                            HopRewriteUtils.cleanupUnreferenced(supersededMinus);
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }

        // select positive (selp) operator (note: same initial pattern as sprop),
        // realized as max(X,0) due to lower memory requirements and simpler sparsity propagation
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            if (left instanceof BinaryOp) { // (X>0)*X
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // left is the superseded (X>0) intermediate
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            if (!applied && right instanceof BinaryOp) { // X*(X>0)
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // right is the superseded (X>0) intermediate; left (X) feeds max
                    // and must not be cleaned up (previously cleaned 'left' by mistake)
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyMultiBinaryToBinaryOperation.
/**
 * Rewrites {@code 1-(X*Y)} into the fused binary operation {@code X 1-* Y}
 * (OpOp2.MINUS1_MULT), which avoids materializing the intermediate X*Y.
 * Applies only to matrix-typed outputs and only when X*Y has a single consumer.
 * The binary op {@code hi} is modified in place.
 *
 * @param hi candidate high-level operator
 * @return hi itself
 */
private static Hop simplifyMultiBinaryToBinaryOperation(Hop hi) {
    // pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
    if (HopRewriteUtils.isBinary(hi, OpOp2.MINUS) && hi.getDataType() == DataType.MATRIX
        && hi.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) hi.getInput().get(0)) == 1
        && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
        && hi.getInput().get(1).getParent().size() == 1) { // single consumer
        BinaryOp bop = (BinaryOp) hi;
        Hop mult = hi.getInput().get(1);
        Hop left = mult.getInput().get(0);
        Hop right = mult.getInput().get(1);
        // set new binaryop type and rewire inputs
        bop.setOp(OpOp2.MINUS1_MULT);
        HopRewriteUtils.removeAllChildReferences(hi);
        HopRewriteUtils.addChildReference(bop, left);
        HopRewriteUtils.addChildReference(bop, right);
        // X*Y had hi as its only consumer and is now orphaned; clean it up so that
        // X and Y no longer list it as a parent
        HopRewriteUtils.cleanupUnreferenced(mult);
        LOG.debug("Applied simplifyMultiBinaryToBinaryOperation.");
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class DMLTranslator method processBooleanExpression.
/**
 * Translates a boolean expression into a hop.
 * A missing right operand denotes logical NOT and yields a UnaryOp. Otherwise,
 * LOGICALAND / LOGICALOR yield a BinaryOp with OpOp2.AND / OpOp2.OR.
 * Operands that are constants are rejected.
 *
 * @param source boolean expression
 * @param target data identifier, or null to derive one from source
 * @param hops map of high-level operators
 * @return high-level operator
 */
private Hop processBooleanExpression(BooleanExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    // logical NOT is the only operator with a single operand
    final boolean isUnary = (source.getRight() == null);

    final boolean hasConstOperand = (source.getLeft().getOutput() instanceof ConstIdentifier)
        || (!isUnary && source.getRight().getOutput() instanceof ConstIdentifier);
    if (hasConstOperand) {
        String msg = source.printErrorLocation() + "Boolean expression with constant unsupported";
        LOG.error(msg);
        throw new RuntimeException(msg);
    }

    Hop lhs = processExpression(source.getLeft(), null, hops);
    Hop rhs = isUnary ? null : processExpression(source.getRight(), null, hops);

    // (type should not be determined by target (e.g., string for print)
    if (target == null)
        target = createTarget(source);
    if (target.getDataType().isScalar())
        target.setValueType(ValueType.BOOLEAN);

    Hop result;
    if (isUnary) {
        result = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, lhs);
    } else {
        OpOp2 op;
        if (source.getOpCode() == Expression.BooleanOp.LOGICALAND) {
            op = OpOp2.AND;
        } else if (source.getOpCode() == Expression.BooleanOp.LOGICALOR) {
            op = OpOp2.OR;
        } else {
            String msg = source.printErrorLocation() + "Unknown boolean operation " + source.getOpCode();
            LOG.error(msg);
            throw new RuntimeException(msg);
        }
        result = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, lhs, rhs);
    }
    result.setParseInfo(source);
    return result;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class DMLTranslator method processBinaryExpression.
/**
 * Construct Hops from parse tree : Process Binary Expression in an
 * assignment statement.
 *
 * Matrix multiplication (MATMULT) becomes an AggBinaryOp (MULT with SUM
 * aggregation). All other supported operators become a BinaryOp with the
 * corresponding OpOp2.
 *
 * @param source binary expression
 * @param target data identifier, or null to derive one from source
 * @param hops map of high-level operators
 * @return high-level operator
 * @throws ParseException if an operand cannot be translated or the operator is unsupported
 */
private Hop processBinaryExpression(BinaryExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    Hop left = processExpression(source.getLeft(), null, hops);
    Hop right = processExpression(source.getRight(), null, hops);
    // NOTE(review): both operands are re-translated if either came back null; the reason
    // is not visible here (kept for compatibility) -- confirm whether it is still needed
    if (left == null || right == null) {
        left = processExpression(source.getLeft(), null, hops);
        right = processExpression(source.getRight(), null, hops);
    }
    // fail with a descriptive error instead of passing a null input to the hop constructor
    if (left == null || right == null) {
        throw new ParseException(source.printErrorLocation()
            + "Unable to translate operand(s) of binary expression: " + source.getOpCode());
    }

    // (type should not be determined by target (e.g., string for print)
    if (target == null) {
        target = createTarget(source);
    }
    target.setValueType(source.getOutput().getValueType());

    Expression.BinaryOp opcode = source.getOpCode();
    Hop currBop;
    if (opcode == Expression.BinaryOp.MATMULT) {
        currBop = new AggBinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, left, right);
    } else {
        // map the parser opcode to the cell-wise hop operator
        OpOp2 op;
        if (opcode == Expression.BinaryOp.PLUS) {
            op = OpOp2.PLUS;
        } else if (opcode == Expression.BinaryOp.MINUS) {
            op = OpOp2.MINUS;
        } else if (opcode == Expression.BinaryOp.MULT) {
            op = OpOp2.MULT;
        } else if (opcode == Expression.BinaryOp.DIV) {
            op = OpOp2.DIV;
        } else if (opcode == Expression.BinaryOp.MODULUS) {
            op = OpOp2.MODULUS;
        } else if (opcode == Expression.BinaryOp.INTDIV) {
            op = OpOp2.INTDIV;
        } else if (opcode == Expression.BinaryOp.POW) {
            op = OpOp2.POW;
        } else {
            throw new ParseException("Unsupported parsing of binary expression: " + opcode);
        }
        currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right);
    }
    setIdentifierParams(currBop, source.getOutput());
    currBop.setParseInfo(source);
    return currBop;
}
Aggregations