Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationStatic, method simplifyDistributiveBinaryOperation.
/**
 * Applies the distributive law to factor out a matrix operand that appears both
 * as a standalone input and inside a multiplication:
 * <pre>
 *   (Y*X+X) -> (Y+1)*X,   (Y*X-X) -> (Y-1)*X
 *   (X+Y*X) -> (1+Y)*X,   (X-Y*X) -> (1-Y)*X
 * </pre>
 * The multiplication may list X and Y in either order; X and Y must be distinct
 * matrix hops, and the outer operator must be listed in
 * {@code LOOKUP_VALID_DISTRIBUTIVE_BINARY}.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the rewritten operator, or {@code hi} if no rewrite applied
 */
private static Hop simplifyDistributiveBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);

    // only matrix-matrix operators from the distributive lookup qualify
    boolean eligible = lhs.getDataType() == DataType.MATRIX
        && rhs.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY);
    if (!eligible)
        return hi;

    // case 1: (Y*X +/- X) -> (Y +/- 1)*X
    if (HopRewriteUtils.isBinary(lhs, OpOp2.MULT)) {
        Hop m1 = lhs.getInput().get(0);
        Hop m2 = lhs.getInput().get(1);
        boolean sharesOperand = m1.getDataType() == DataType.MATRIX
            && m2.getDataType() == DataType.MATRIX
            && (rhs == m1 || rhs == m2)
            && m1 != m2;
        if (sharesOperand) {
            Hop other = (rhs == m1) ? m2 : m1;
            LiteralOp one = new LiteralOp(1);
            BinaryOp factor = HopRewriteUtils.createBinary(other, one, bop.getOp());
            BinaryOp product = HopRewriteUtils.createBinary(factor, rhs, OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, lhs);
            LOG.debug("Applied simplifyDistributiveBinaryOperation1");
            return product;
        }
    }

    // case 2: (X +/- Y*X) -> (1 +/- Y)*X
    if (HopRewriteUtils.isBinary(rhs, OpOp2.MULT)) {
        Hop m1 = rhs.getInput().get(0);
        Hop m2 = rhs.getInput().get(1);
        boolean sharesOperand = m1.getDataType() == DataType.MATRIX
            && m2.getDataType() == DataType.MATRIX
            && (lhs == m1 || lhs == m2)
            && m1 != m2;
        if (sharesOperand) {
            Hop other = (lhs == m1) ? m2 : m1;
            LiteralOp one = new LiteralOp(1);
            BinaryOp factor = HopRewriteUtils.createBinary(one, other, bop.getOp());
            BinaryOp product = HopRewriteUtils.createBinary(factor, lhs, OpOp2.MULT);
            HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, rhs);
            LOG.debug("Applied simplifyDistributiveBinaryOperation2");
            return product;
        }
    }

    return hi;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationStatic, method fuseBinarySubDAGToUnaryOperation.
/**
 * Handles simplification of more complex sub-DAGs into a single cheaper operator:
 * <pre>
 *   X*(1-X)        -> sprop(X)
 *   (1-X)*X        -> sprop(X)
 *   1/(1+exp(-X))  -> sigmoid(X)
 *   1/(1+exp(X))   -> sigmoid(-X)
 *   (X>0)*X        -> max(X,0)
 *   X*(X>0)        -> max(X,0)
 * </pre>
 * After a rewrite, {@code cleanupUnreferenced} is applied to the replaced
 * intermediates (e.g. the {@code (1-X)} or {@code (X>0)} hop), never to X itself,
 * which the new operator still consumes.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the fused operator, or {@code hi} if no rewrite applied
 */
private static Hop fuseBinarySubDAGToUnaryOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        boolean applied = false;

        // sample proportion (sprop) operator
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // (1-X)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left1) == 1 && left2 == right && bleft.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // left is the (1-X) intermediate being replaced
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1");
                }
            }
            // X*(1-X)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right1 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right1) == 1 && right2 == left && bright.getOp() == OpOp2.MINUS) {
                    UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP);
                    HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                    // here the (1-X) intermediate is 'right'; cleaning 'left' (X, still
                    // consumed by sprop) would leave (1-X) dangling as a parent of X
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = unary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2");
                }
            }
        }

        // sigmoid operator
        if (!applied && bop.getOp() == OpOp2.DIV && left.getDataType() == DataType.SCALAR && right.getDataType() == DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 1 && right instanceof BinaryOp) {
            // note: if there are multiple consumers on the intermediate,
            // we follow the heuristic that redundant computation is more beneficial,
            // i.e., we still fuse but leave the intermediate for the other consumers
            BinaryOp bop2 = (BinaryOp) right;
            Hop left2 = bop2.getInput().get(0);
            Hop right2 = bop2.getInput().get(1);
            if (bop2.getOp() == OpOp2.PLUS && left2.getDataType() == DataType.SCALAR && right2.getDataType() == DataType.MATRIX && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 1 && right2 instanceof UnaryOp) {
                UnaryOp uop = (UnaryOp) right2;
                Hop uopin = uop.getInput().get(0);
                if (uop.getOp() == OpOp1.EXP) {
                    UnaryOp unary = null;
                    // the (0-X) intermediate replaced by pattern 1, if that pattern matched
                    Hop negation = null;
                    if (HopRewriteUtils.isBinary(uopin, OpOp2.MINUS)) {
                        // Pattern 1: (1/(1 + exp(-X))
                        BinaryOp bop3 = (BinaryOp) uopin;
                        Hop left3 = bop3.getInput().get(0);
                        Hop right3 = bop3.getInput().get(1);
                        if (left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left3) == 0) {
                            unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
                            negation = bop3;
                        }
                    } else {
                        // Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by
                        // the 'remove unnecessary minus' rewrite --> reintroduce the minus
                        BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin);
                        unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID);
                    }
                    if (unary != null) {
                        HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                        // clean up outer-to-inner so each intermediate loses its parent first
                        if (negation != null)
                            HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop, negation);
                        else
                            HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                        hi = unary;
                        applied = true;
                        LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1");
                    }
                }
            }
        }

        // select positive (selp) operator (note: same initial pattern as sprop)
        if (!applied && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && right.getDataType() == DataType.MATRIX) {
            // to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation
            // (X>0)*X
            if (left instanceof BinaryOp) {
                BinaryOp bleft = (BinaryOp) left;
                Hop left1 = bleft.getInput().get(0);
                Hop left2 = bleft.getInput().get(1);
                if (left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) left2) == 0 && left1 == right && (bleft.getOp() == OpOp2.GREATER)) {
                    BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // left is the (X>0) intermediate being replaced
                    HopRewriteUtils.cleanupUnreferenced(bop, left);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a");
                }
            }
            // X*(X>0)
            if (!applied && right instanceof BinaryOp) {
                BinaryOp bright = (BinaryOp) right;
                Hop right1 = bright.getInput().get(0);
                Hop right2 = bright.getInput().get(1);
                if (right2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) right2) == 0 && right1 == left && bright.getOp() == OpOp2.GREATER) {
                    BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX);
                    HopRewriteUtils.replaceChildReference(parent, bop, binary, pos);
                    // here the (X>0) intermediate is 'right', not 'left' (X, still consumed by max)
                    HopRewriteUtils.cleanupUnreferenced(bop, right);
                    hi = binary;
                    applied = true;
                    LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b");
                }
            }
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationStatic, method simplifyMultiBinaryToBinaryOperation.
/**
 * Rewrites {@code 1-(X*Y)} into the fused binary operator {@code X 1-* Y}
 * (MINUS1_MULT), avoiding the X*Y intermediate. The rewrite is applied in place
 * on {@code hi} and only when the X*Y hop has a single consumer.
 *
 * @param hi high-level operator
 * @return {@code hi}, possibly rewritten in place
 */
private static Hop simplifyMultiBinaryToBinaryOperation(Hop hi) {
    // pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
    if (HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
        && hi.getDataType() == DataType.MATRIX
        && hi.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) hi.getInput().get(0)) == 1
        && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
        && hi.getInput().get(1).getParent().size() == 1) // single consumer
    {
        BinaryOp bop = (BinaryOp) hi;
        Hop mult = hi.getInput().get(1);
        Hop left = mult.getInput().get(0);
        Hop right = mult.getInput().get(1);
        // set new binaryop type and rewire inputs
        bop.setOp(OpOp2.MINUS1_MULT);
        HopRewriteUtils.removeAllChildReferences(hi);
        // hi was the only consumer of X*Y, so that hop is now dead; detach it from
        // X and Y so it does not linger in their parent lists (parent-count checks)
        HopRewriteUtils.removeAllChildReferences(mult);
        HopRewriteUtils.addChildReference(bop, left);
        HopRewriteUtils.addChildReference(bop, right);
        LOG.debug("Applied simplifyMultiBinaryToBinaryOperation.");
    }
    return hi;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationStatic, method fuseOrderOperationChain.
/**
 * Fuses a chain of single-consumer order (SORT) operations with literal 'by'
 * columns, identical 'decreasing' flag and index.return=FALSE into one
 * multi-column order, e.g. {@code order(order(X,2),1) -> order(X,(1,2))}.
 * Not applied in Hadoop execution mode.
 *
 * @param hi high-level operator
 * @return the fused order operator, or {@code hi} if no rewrite applied
 */
private static Hop fuseOrderOperationChain(Hop hi) {
    // order(order(X,2),1) -> order(X, (12)),
    if (HopRewriteUtils.isReorg(hi, ReOrgOp.SORT)
        && hi.getInput().get(1) instanceof LiteralOp // scalar by
        && hi.getInput().get(2) instanceof LiteralOp // scalar desc
        && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) // not ixret
        && !OptimizerUtils.isHadoopExecutionMode()) {
        LiteralOp by = (LiteralOp) hi.getInput().get(1);
        boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp) hi.getInput().get(2));
        // find chain of order operations with same desc/ixret configuration and single consumers
        ArrayList<LiteralOp> byList = new ArrayList<>();
        byList.add(by);
        Hop input = hi.getInput().get(0);
        // note: the ixret check must inspect the inner 'input' op; the outer op's
        // ixret was already verified above
        while (HopRewriteUtils.isReorg(input, ReOrgOp.SORT)
            && input.getInput().get(1) instanceof LiteralOp // scalar by
            && HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc)
            && HopRewriteUtils.isLiteralOfValue(input.getInput().get(3), false)
            && input.getParent().size() == 1) {
            byList.add((LiteralOp) input.getInput().get(1));
            input = input.getInput().get(0);
        }
        // merge order chain if at least two instances
        if (byList.size() >= 2) {
            // create new order operations
            ArrayList<Hop> inputs = new ArrayList<>();
            inputs.add(input);
            inputs.add(HopRewriteUtils.createDataGenOpByVal(byList, 1, byList.size()));
            inputs.add(new LiteralOp(desc));
            inputs.add(new LiteralOp(false));
            Hop hnew = HopRewriteUtils.createReorg(inputs, ReOrgOp.SORT);
            // cleanup references recursively
            Hop current = hi;
            while (current != input) {
                Hop tmp = current.getInput().get(0);
                HopRewriteUtils.removeAllChildReferences(current);
                current = tmp;
            }
            // rewire all parents (avoid anomalies with replicated datagen)
            List<Hop> parents = new ArrayList<>(hi.getParent());
            for (Hop p : parents)
                HopRewriteUtils.replaceChildReference(p, hi, hnew);
            hi = hnew;
            LOG.debug("Applied fuseOrderOperationChain (line " + hi.getBeginLine() + ").");
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class DMLTranslator, method processIndexingExpression.
/**
 * Builds the IndexingOp hop for a right-indexing expression such as X[rl:ru, cl:cu].
 * Missing lower bounds default to the literal 1. A missing upper bound becomes the
 * source's original dimension when known (not -1); otherwise it becomes an
 * nrow/ncol hop over the source.
 *
 * @param source indexed identifier (source variable plus bound expressions)
 * @param target output identifier; created via {@code createTarget(source)} if null
 * @param hops map from variable names to their current hops
 * @return the new indexing hop
 */
private Hop processIndexingExpression(IndexedIdentifier source, DataIdentifier target, HashMap<String, Hop> hops) {
    // process Hops for indexes (for source)
    Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null;
    if (source.getRowLowerBound() != null)
        rowLowerHops = processExpression(source.getRowLowerBound(), null, hops);
    else
        rowLowerHops = new LiteralOp(1);
    if (source.getRowUpperBound() != null)
        rowUpperHops = processExpression(source.getRowUpperBound(), null, hops);
    else {
        if (source.getOrigDim1() != -1)
            rowUpperHops = new LiteralOp(source.getOrigDim1());
        else {
            rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(source.getName()));
            rowUpperHops.setParseInfo(source);
        }
    }
    if (source.getColLowerBound() != null)
        colLowerHops = processExpression(source.getColLowerBound(), null, hops);
    else
        colLowerHops = new LiteralOp(1);
    if (source.getColUpperBound() != null)
        colUpperHops = processExpression(source.getColUpperBound(), null, hops);
    else {
        if (source.getOrigDim2() != -1)
            colUpperHops = new LiteralOp(source.getOrigDim2());
        else {
            colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(source.getName()));
            // keep parse info consistent with the NROW fallback above
            colUpperHops.setParseInfo(source);
        }
    }
    if (target == null) {
        target = createTarget(source);
    }
    // unknown nnz after range indexing (applies to indexing op but also
    // data dependent operations)
    target.setNnz(-1);
    Hop indexOp = new IndexingOp(target.getName(), target.getDataType(), target.getValueType(), hops.get(source.getName()), rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper());
    indexOp.setParseInfo(target);
    setIdentifierParams(indexOp, target);
    return indexOp;
}
Aggregations