Usage of org.apache.sysml.hops.BinaryOp in the Apache project incubator-systemml:
class RewriteAlgebraicSimplificationStatic, method simplifyTraceMatrixMult.
/**
 * Rewrites trace(X %*% Y) into the equivalent element-wise form sum(X * t(Y)),
 * using the identity trace(XY) = sum_i sum_k X[i,k] * Y[k,i].
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate trace aggregate)
 * @param pos position of hi among the inputs of parent
 * @return the new sum operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) {
	// only trace() aggregates are candidates
	boolean isTrace = hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE;
	if (!isTrace) {
		return hi;
	}

	// the aggregated input must be a matrix multiplication X %*% Y
	Hop matMult = hi.getInput().get(0);
	if (!HopRewriteUtils.isMatrixMultiply(matMult)) {
		return hi;
	}

	Hop lhs = matMult.getInput().get(0);
	Hop rhs = matMult.getInput().get(1);

	// build sum(X * t(Y)); per the original note, createTranspose refreshes sizes internally
	AggUnaryOp fused = HopRewriteUtils.createSum(
		HopRewriteUtils.createBinary(lhs, HopRewriteUtils.createTranspose(rhs), OpOp2.MULT));

	// hang the new subdag under parent and clean up the replaced trace/mmult nodes
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, matMult);
	LOG.debug("Applied simplifyTraceMatrixMult");
	return fused;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache project incubator-systemml:
class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndBinaryOperation.
/**
 * Handle removal of unnecessary binary operations over rand data by folding the
 * scalar operation into the data generator's min/max parameters.
 *
 * Scalar literal operand (requires literal min/max and a uniform pdf; a copy of
 * the generator is created via scale and shift):
 * rand*7 -> rand(min*7,max*7); rand/7 -> rand(min*(1/7),max*(1/7));
 * rand+7 -> rand(min+7,max+7); rand-7 -> rand(min+(-7),max+(-7));
 * 7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7)
 *
 * Scalar variable operand s on the right (the generator's min/max inputs are
 * replaced in place):
 * rand(min=0,max=0)+s -> rand(min=s,max=s);
 * rand(min=0,max=1,pdf=uniform)*s -> rand(min=0,max=s);
 * rand(min=1,max=1)*s -> rand(min=s,max=s)
 *
 * All cases require the rand operator to have this binary operation as its only
 * parent. Division by a literal zero is not fused: the scale factor 1/0 is
 * infinite, and scaling a zero bound by it presumably yields NaN (0*Infinity)
 * rather than the element-wise result of rand/0.
 *
 * @param hi high-order operation
 * @return the fused data generator if a rewrite applied, otherwise hi
 */
private static Hop fuseDatagenAndBinaryOperation(Hop hi) {
	if (hi instanceof BinaryOp) {
		BinaryOp bop = (BinaryOp) hi;
		Hop left = bop.getInput().get(0);
		Hop right = bop.getInput().get(1);
		// left input rand and hence output matrix double, right scalar literal
		if (HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right instanceof LiteralOp && left.getParent().size() == 1) {
			DataGenOp inputGen = (DataGenOp) left;
			Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
			Hop min = inputGen.getInput(DataExpression.RAND_MIN);
			Hop max = inputGen.getInput(DataExpression.RAND_MAX);
			double sval = ((LiteralOp) right).getDoubleValue();
			boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
			// rand/0 cannot be expressed as a finite scale factor; leave it untouched
			boolean divByZero = bop.getOp() == OpOp2.DIV && sval == 0;
			if (HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV) && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform && !divByZero) {
				// create fused data gen operator
				DataGenOp gen = null;
				switch (bop.getOp()) {
					// fuse via scale and shift
					case MULT:
						gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0);
						break;
					case PLUS:
					case MINUS:
						gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval * ((bop.getOp() == OpOp2.MINUS) ? -1 : 1));
						break;
					case DIV:
						gen = HopRewriteUtils.copyDataGenOp(inputGen, 1 / sval, 0);
						break;
					default:
						// unreachable: isBinary above restricts the operator set
						break;
				}
				// defensive: never rewire parents to a null operator
				if (gen != null) {
					// rewire all parents (avoid anomalies with replicated datagen)
					List<Hop> parents = new ArrayList<>(bop.getParent());
					for (Hop p : parents) {
						HopRewriteUtils.replaceChildReference(p, bop, gen);
					}
					hi = gen;
					LOG.debug("Applied fuseDatagenAndBinaryOperation1 " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
				}
			}
		} else if (HopRewriteUtils.isDataGenOp(right, DataGenMethod.RAND) && left instanceof LiteralOp && right.getParent().size() == 1) {
			// right input rand and hence output matrix double, left scalar literal
			DataGenOp inputGen = (DataGenOp) right;
			Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
			Hop min = inputGen.getInput(DataExpression.RAND_MIN);
			Hop max = inputGen.getInput(DataExpression.RAND_MAX);
			double sval = ((LiteralOp) left).getDoubleValue();
			boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
			// only commutative ops are fused here (s-rand and s/rand are not a simple scale/shift)
			if ((bop.getOp() == OpOp2.MULT || bop.getOp() == OpOp2.PLUS) && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform) {
				// create fused data gen operator
				DataGenOp gen = (bop.getOp() == OpOp2.MULT) ?
					HopRewriteUtils.copyDataGenOp(inputGen, sval, 0) :
					HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); // OpOp2.PLUS
				// defensive: never rewire parents to a null operator
				if (gen != null) {
					// rewire all parents (avoid anomalies with replicated datagen)
					List<Hop> parents = new ArrayList<>(bop.getParent());
					for (Hop p : parents) {
						HopRewriteUtils.replaceChildReference(p, bop, gen);
					}
					hi = gen;
					LOG.debug("Applied fuseDatagenAndBinaryOperation2 " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
				}
			}
		} else if (HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right.getDataType().isScalar() && left.getParent().size() == 1) {
			// left input rand and hence output matrix double, right scalar variable
			DataGenOp gen = (DataGenOp) left;
			Hop min = gen.getInput(DataExpression.RAND_MIN);
			Hop max = gen.getInput(DataExpression.RAND_MAX);
			Hop pdf = gen.getInput(DataExpression.RAND_PDF);
			boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
			if (HopRewriteUtils.isBinary(bop, OpOp2.PLUS) && HopRewriteUtils.isLiteralOfValue(min, 0) && HopRewriteUtils.isLiteralOfValue(max, 0)) {
				// rand(min=0,max=0)+s -> rand(min=s,max=s)
				gen.setInput(DataExpression.RAND_MIN, right);
				gen.setInput(DataExpression.RAND_MAX, right);
				// rewire all parents (avoid anomalies with replicated datagen)
				List<Hop> parents = new ArrayList<>(bop.getParent());
				for (Hop p : parents) {
					HopRewriteUtils.replaceChildReference(p, bop, gen);
				}
				hi = gen;
				LOG.debug("Applied fuseDatagenAndBinaryOperation3a " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
			} else if (HopRewriteUtils.isBinary(bop, OpOp2.MULT) && ((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform) || HopRewriteUtils.isLiteralOfValue(min, 1)) && HopRewriteUtils.isLiteralOfValue(max, 1)) {
				// rand(min=1,max=1)*s -> rand(min=s,max=s); a zero min stays zero
				if (HopRewriteUtils.isLiteralOfValue(min, 1)) {
					gen.setInput(DataExpression.RAND_MIN, right);
				}
				// max=1 is always scaled to s
				gen.setInput(DataExpression.RAND_MAX, right);
				// rewire all parents (avoid anomalies with replicated datagen)
				List<Hop> parents = new ArrayList<>(bop.getParent());
				for (Hop p : parents) {
					HopRewriteUtils.replaceChildReference(p, bop, gen);
				}
				hi = gen;
				LOG.debug("Applied fuseDatagenAndBinaryOperation3b " + "(" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
			}
		}
	}
	return hi;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache project incubator-systemml:
class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryBinaryOperation.
/**
 * Handle removal of unnecessary binary operations between a matrix and a
 * scalar literal.
 *
 * X/1, X*1, 1*X, X-0 -> X (the binary operator is bypassed in parent)
 * -1*X, X*-1 -> 0-X (the binary operator is rewritten in place into a minus,
 * for mechanical reasons)
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the operator now occupying hi's position (the matrix input if the
 *         operation was bypassed, otherwise hi)
 */
private static Hop removeUnnecessaryBinaryOperation(Hop parent, Hop hi, int pos) {
	if (!(hi instanceof BinaryOp)) {
		return hi;
	}
	BinaryOp bop = (BinaryOp) hi;
	OpOp2 op = bop.getOp();
	Hop left = bop.getInput().get(0);
	Hop right = bop.getInput().get(1);
	boolean leftIsMatrix = left.getDataType() == DataType.MATRIX;
	boolean rightIsMatrix = right.getDataType() == DataType.MATRIX;

	if (leftIsMatrix && right instanceof LiteralOp) {
		// matrix (op) literal
		double literal = ((LiteralOp) right).getDoubleValue();
		if (literal == 1.0) {
			// X/1 or X*1 -> X
			if (op == OpOp2.DIV || op == OpOp2.MULT) {
				HopRewriteUtils.replaceChildReference(parent, bop, left, pos);
				hi = left;
				LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line " + bop.getBeginLine() + ")");
			}
		} else if (literal == 0.0) {
			// X-0 -> X
			if (op == OpOp2.MINUS) {
				HopRewriteUtils.replaceChildReference(parent, bop, left, pos);
				hi = left;
				LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line " + bop.getBeginLine() + ")");
			}
		} else if (literal == -1.0) {
			// X*-1 -> 0-X: drop the literal at pos 1, prepend a zero at pos 0
			if (op == OpOp2.MULT) {
				bop.setOp(OpOp2.MINUS);
				HopRewriteUtils.removeChildReferenceByPos(bop, right, 1);
				HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0);
				hi = bop;
				LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line " + bop.getBeginLine() + ")");
			}
		}
	} else if (rightIsMatrix && left instanceof LiteralOp) {
		// literal (op) matrix
		double literal = ((LiteralOp) left).getDoubleValue();
		if (literal == 1.0) {
			// 1*X -> X
			if (op == OpOp2.MULT) {
				HopRewriteUtils.replaceChildReference(parent, bop, right, pos);
				hi = right;
				LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line " + bop.getBeginLine() + ")");
			}
		} else if (literal == -1.0) {
			// -1*X -> 0-X: swap the -1 literal for a zero literal in place
			if (op == OpOp2.MULT) {
				bop.setOp(OpOp2.MINUS);
				HopRewriteUtils.replaceChildReference(bop, left, new LiteralOp(0), 0);
				hi = bop;
				LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line " + bop.getBeginLine() + ")");
			}
		}
	}
	return hi;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache project incubator-systemml:
class RewriteAlgebraicSimplificationStatic, method canonicalizeMatrixMultScalarAdd.
/**
 * Canonicalizes scalar additions/subtractions around a matrix multiplication
 * into the single form U%*%V + s, so that later rewrites (e.g., wdivmm or wcemm
 * with epsilon) only need to match one pattern:
 * eps + U%*%V -> U%*%V + eps (operands swapped in place)
 * U%*%V - eps -> U%*%V + (-eps) (operator flipped, scalar negated)
 *
 * @param hi high-level operator
 * @return hi (modified in place if a pattern matched)
 */
private static Hop canonicalizeMatrixMultScalarAdd(Hop hi) {
	if (!(hi instanceof BinaryOp)) {
		return hi;
	}
	BinaryOp bop = (BinaryOp) hi;
	OpOp2 op = bop.getOp();
	Hop lhs = bop.getInput().get(0);
	Hop rhs = bop.getInput().get(1);

	if (op == OpOp2.PLUS && lhs.getDataType().isScalar() && rhs instanceof AggBinaryOp) {
		// (eps + U%*%V) -> (U%*%V + eps): re-add the inputs in swapped order
		HopRewriteUtils.removeAllChildReferences(bop);
		HopRewriteUtils.addChildReference(bop, rhs, 0);
		HopRewriteUtils.addChildReference(bop, lhs, 1);
		LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line " + hi.getBeginLine() + ").");
	} else if (op == OpOp2.MINUS && rhs.getDataType().isScalar() && lhs instanceof AggBinaryOp) {
		// (U%*%V - eps) -> (U%*%V + (-eps))
		bop.setOp(OpOp2.PLUS);
		Hop negated = HopRewriteUtils.createBinaryMinus(rhs);
		HopRewriteUtils.replaceChildReference(bop, rhs, negated, 1);
		LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line " + hi.getBeginLine() + ").");
	}
	return hi;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache project incubator-systemml:
class RewriteAlgebraicSimplificationStatic, method simplifyBinaryMatrixScalarOperation.
/**
 * Pushes a scalar cast through an element-wise binary operation so that the
 * operation is evaluated over scalars:
 * as.scalar(X op Y) -> as.scalar(X) op as.scalar(Y)
 * as.scalar(X op s) -> as.scalar(X) op s
 * as.scalar(s op X) -> s op as.scalar(X)
 * Only binary operators listed in LOOKUP_VALID_SCALAR_BINARY are rewritten.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate as.scalar cast)
 * @param pos position of hi among the inputs of parent
 * @return the new scalar binary operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyBinaryMatrixScalarOperation(Hop parent, Hop hi, int pos) {
	if (HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) && hi.getInput().get(0) instanceof BinaryOp && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) {
		BinaryOp bin = (BinaryOp) hi.getInput().get(0);
		Hop binLeft = bin.getInput().get(0);
		Hop binRight = bin.getInput().get(1);
		BinaryOp bout = null;
		if (binLeft.getDataType() == DataType.MATRIX && binRight.getDataType() == DataType.MATRIX) {
			// as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
			UnaryOp cast1 = HopRewriteUtils.createUnary(binLeft, OpOp1.CAST_AS_SCALAR);
			UnaryOp cast2 = HopRewriteUtils.createUnary(binRight, OpOp1.CAST_AS_SCALAR);
			bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp());
		} else if (binLeft.getDataType() == DataType.MATRIX) {
			// as.scalar(X*s) -> as.scalar(X) * s
			UnaryOp cast = HopRewriteUtils.createUnary(binLeft, OpOp1.CAST_AS_SCALAR);
			bout = HopRewriteUtils.createBinary(cast, binRight, bin.getOp());
		} else if (binRight.getDataType() == DataType.MATRIX) {
			// as.scalar(s*X) -> s * as.scalar(X)
			UnaryOp cast = HopRewriteUtils.createUnary(binRight, OpOp1.CAST_AS_SCALAR);
			bout = HopRewriteUtils.createBinary(binLeft, cast, bin.getOp());
		}
		if (bout != null) {
			HopRewriteUtils.replaceChildReference(parent, hi, bout, pos);
			// return the operator now referenced by parent (as the other rewrites do),
			// so subsequent rewrites do not operate on the detached cast node
			hi = bout;
			LOG.debug("Applied simplifyBinaryMatrixScalarOperation.");
		}
	}
	return hi;
}
Aggregations