Use of org.apache.sysml.hops.DataGenOp in the Apache SystemML project.
Class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryVectorizeOperation.
/**
 * Removes an unnecessary vectorization of a scalar: if one input of a matrix
 * binary operation is a rand data generator with a constant value, that input
 * is replaced by the generator's min-value input.
 *
 * At most one side is rewritten, in order to keep the output semantically
 * equivalent. A future extension might remove the vectors from both sides,
 * compute the binary op on scalars, and feed the result into a data generator
 * of the original dimensions.
 *
 * @param hi high-level operator to inspect
 * @return the given hop (its inputs may have been rewired in place)
 */
private static Hop removeUnnecessaryVectorizeOperation(Hop hi) {
    // applies to all binary matrix operations that support matrix-scalar execution
    if (!(hi instanceof BinaryOp)
            || hi.getDataType() != DataType.MATRIX
            || !((BinaryOp) hi).supportsMatrixScalarOperations()) {
        return hi;
    }

    BinaryOp binary = (BinaryOp) hi;
    Hop lhs = binary.getInput().get(0);
    Hop rhs = binary.getInput().get(1);

    // no outer: leave column-vector <op> row-vector combinations untouched
    if (lhs.getDim1() > 1 && lhs.getDim2() == 1 && rhs.getDim1() == 1 && rhs.getDim2() > 1) {
        return hi;
    }

    if (lhs.getDataType() == DataType.MATRIX && rhs instanceof DataGenOp) {
        // check and remove right vectorized scalar
        DataGenOp rightGen = (DataGenOp) rhs;
        if (rightGen.getOp() == DataGenMethod.RAND && rightGen.hasConstantValue()) {
            Hop constant = rightGen.getInput().get(rightGen.getParamIndex(DataExpression.RAND_MIN));
            HopRewriteUtils.replaceChildReference(binary, rightGen, constant, 1);
            HopRewriteUtils.cleanupUnreferenced(rightGen);
            LOG.debug("Applied removeUnnecessaryVectorizeOperation1");
        }
    } else if (rhs.getDataType() == DataType.MATRIX && lhs instanceof DataGenOp) {
        // check and remove left vectorized scalar
        DataGenOp leftGen = (DataGenOp) lhs;
        if (leftGen.getOp() == DataGenMethod.RAND && leftGen.hasConstantValue()) {
            // additional dimension checks that only the left-hand case requires
            if ((lhs.getDim2() == 1 || rhs.getDim2() > 1) && (lhs.getDim1() == 1 || rhs.getDim1() > 1)) {
                Hop constant = leftGen.getInput().get(leftGen.getParamIndex(DataExpression.RAND_MIN));
                HopRewriteUtils.replaceChildReference(binary, leftGen, constant, 0);
                HopRewriteUtils.cleanupUnreferenced(leftGen);
                LOG.debug("Applied removeUnnecessaryVectorizeOperation2");
            }
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.DataGenOp in the Apache SystemML project.
Class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndBinaryOperation.
/**
 * Removes unnecessary binary operations over rand data by folding the scalar
 * operand into the data generator.
 *
 * Literal scalar, literal min/max, uniform pdf (fused via scale and shift):
 *   rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7);
 *   rand-7 -> rand(min+(-7),max+(-7)); rand/7 is scaled by 1/7;
 *   7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7)
 *
 * Arbitrary scalar s on the right-hand side:
 *   rand(min=0,max=0)+s -> rand(min=s,max=s);
 *   rand(min=0,max=1)*s -> rand(min=0,max=s) (uniform pdf only);
 *   rand(min=1,max=1)*s -> rand(min=s,max=s)
 *
 * All patterns require that the rand operator has exactly one parent.
 *
 * @param hi high-level operator to inspect
 * @return the fused data generator if a pattern applied, otherwise the given hop
 */
private static Hop fuseDatagenAndBinaryOperation(Hop hi) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }

    BinaryOp binary = (BinaryOp) hi;
    Hop lhs = binary.getInput().get(0);
    Hop rhs = binary.getInput().get(1);

    if (HopRewriteUtils.isDataGenOp(lhs, DataGenMethod.RAND) && rhs instanceof LiteralOp && lhs.getParent().size() == 1) {
        // left input rand and hence output matrix double, right scalar literal
        DataGenOp randOp = (DataGenOp) lhs;
        Hop pdfIn = randOp.getInput(DataExpression.RAND_PDF);
        Hop minIn = randOp.getInput(DataExpression.RAND_MIN);
        Hop maxIn = randOp.getInput(DataExpression.RAND_MAX);
        double scalar = ((LiteralOp) rhs).getDoubleValue();
        boolean uniform = pdfIn instanceof LiteralOp
                && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdfIn).getStringValue());

        if (HopRewriteUtils.isBinary(binary, OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV)
                && minIn instanceof LiteralOp && maxIn instanceof LiteralOp && uniform) {
            // create fused data gen operator (fuse via scale and shift)
            DataGenOp fused = null;
            if (binary.getOp() == OpOp2.MULT) {
                fused = HopRewriteUtils.copyDataGenOp(randOp, scalar, 0);
            } else if (binary.getOp() == OpOp2.PLUS || binary.getOp() == OpOp2.MINUS) {
                // subtraction is expressed as adding the negated scalar
                fused = HopRewriteUtils.copyDataGenOp(randOp, 1, scalar * ((binary.getOp() == OpOp2.MINUS) ? -1 : 1));
            } else if (binary.getOp() == OpOp2.DIV) {
                fused = HopRewriteUtils.copyDataGenOp(randOp, 1 / scalar, 0);
            }

            // rewire all parents (avoid anomalies with replicated datagen)
            List<Hop> consumers = new ArrayList<>(binary.getParent());
            for (Hop consumer : consumers) {
                HopRewriteUtils.replaceChildReference(consumer, binary, fused);
            }
            hi = fused;
            LOG.debug("Applied fuseDatagenAndBinaryOperation1 " + "(" + binary.getFilename() + ", line " + binary.getBeginLine() + ").");
        }
    } else if (rhs instanceof DataGenOp && ((DataGenOp) rhs).getOp() == DataGenMethod.RAND && lhs instanceof LiteralOp && rhs.getParent().size() == 1) {
        // right input rand and hence output matrix double, left scalar literal
        DataGenOp randOp = (DataGenOp) rhs;
        Hop pdfIn = randOp.getInput(DataExpression.RAND_PDF);
        Hop minIn = randOp.getInput(DataExpression.RAND_MIN);
        Hop maxIn = randOp.getInput(DataExpression.RAND_MAX);
        double scalar = ((LiteralOp) lhs).getDoubleValue();
        boolean uniform = pdfIn instanceof LiteralOp
                && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdfIn).getStringValue());

        if ((binary.getOp() == OpOp2.MULT || binary.getOp() == OpOp2.PLUS)
                && minIn instanceof LiteralOp && maxIn instanceof LiteralOp && uniform) {
            // create fused data gen operator: scale for MULT, shift for PLUS
            DataGenOp fused = (binary.getOp() == OpOp2.MULT)
                    ? HopRewriteUtils.copyDataGenOp(randOp, scalar, 0)
                    : HopRewriteUtils.copyDataGenOp(randOp, 1, scalar);

            // rewire all parents (avoid anomalies with replicated datagen)
            List<Hop> consumers = new ArrayList<>(binary.getParent());
            for (Hop consumer : consumers) {
                HopRewriteUtils.replaceChildReference(consumer, binary, fused);
            }
            hi = fused;
            LOG.debug("Applied fuseDatagenAndBinaryOperation2 " + "(" + binary.getFilename() + ", line " + binary.getBeginLine() + ").");
        }
    } else if (HopRewriteUtils.isDataGenOp(lhs, DataGenMethod.RAND) && rhs.getDataType().isScalar() && lhs.getParent().size() == 1) {
        // left input rand and hence output matrix double, right scalar variable;
        // here the existing rand operator is modified in place
        DataGenOp randOp = (DataGenOp) lhs;
        Hop minIn = randOp.getInput(DataExpression.RAND_MIN);
        Hop maxIn = randOp.getInput(DataExpression.RAND_MAX);
        Hop pdfIn = randOp.getInput(DataExpression.RAND_PDF);
        boolean uniform = pdfIn instanceof LiteralOp
                && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdfIn).getStringValue());

        if (HopRewriteUtils.isBinary(binary, OpOp2.PLUS)
                && HopRewriteUtils.isLiteralOfValue(minIn, 0) && HopRewriteUtils.isLiteralOfValue(maxIn, 0)) {
            // all-zero matrix plus scalar: the scalar becomes both bounds
            randOp.setInput(DataExpression.RAND_MIN, rhs);
            randOp.setInput(DataExpression.RAND_MAX, rhs);

            // rewire all parents (avoid anomalies with replicated datagen)
            List<Hop> consumers = new ArrayList<>(binary.getParent());
            for (Hop consumer : consumers) {
                HopRewriteUtils.replaceChildReference(consumer, binary, randOp);
            }
            hi = randOp;
            LOG.debug("Applied fuseDatagenAndBinaryOperation3a " + "(" + binary.getFilename() + ", line " + binary.getBeginLine() + ").");
        } else if (HopRewriteUtils.isBinary(binary, OpOp2.MULT)
                && ((HopRewriteUtils.isLiteralOfValue(minIn, 0) && uniform) || HopRewriteUtils.isLiteralOfValue(minIn, 1))
                && HopRewriteUtils.isLiteralOfValue(maxIn, 1)) {
            // the lower bound is only replaced when it was 1; a lower bound of 0 is kept
            if (HopRewriteUtils.isLiteralOfValue(minIn, 1)) {
                randOp.setInput(DataExpression.RAND_MIN, rhs);
            }
            // the upper bound always becomes the scalar
            randOp.setInput(DataExpression.RAND_MAX, rhs);

            // rewire all parents (avoid anomalies with replicated datagen)
            List<Hop> consumers = new ArrayList<>(binary.getParent());
            for (Hop consumer : consumers) {
                HopRewriteUtils.replaceChildReference(consumer, binary, randOp);
            }
            hi = randOp;
            LOG.debug("Applied fuseDatagenAndBinaryOperation3b " + "(" + binary.getFilename() + ", line " + binary.getBeginLine() + ").");
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.DataGenOp in the Apache SystemML project.
Class RewriteAlgebraicSimplificationStatic, method simplifyCTableWithConstMatrixInputs.
/**
 * Replaces constant-valued data generator inputs of a ctable operation by
 * their scalar value: table(X, matrix(1,...), matrix(7, ...)) -> table(X, 1, 7).
 *
 * @param hi high-level operator to inspect
 * @return the given hop (its inputs may have been rewired in place)
 */
private static Hop simplifyCTableWithConstMatrixInputs(Hop hi) {
    if (!HopRewriteUtils.isTernary(hi, OpOp3.CTABLE)) {
        return hi;
    }

    // note: the first input is always expected to be a matrix, so start at position 1
    for (int pos = 1; pos < hi.getInput().size(); pos++) {
        Hop candidate = hi.getInput().get(pos);
        if (!HopRewriteUtils.isDataGenOpWithConstantValue(candidate)) {
            continue;
        }
        // the min input of the constant data generator is used as the scalar replacement
        Hop constant = ((DataGenOp) candidate).getInput(DataExpression.RAND_MIN);
        HopRewriteUtils.replaceChildReference(hi, candidate, constant, pos);
        LOG.debug("Applied simplifyCTableWithConstMatrixInputs" + pos + " (line " + hi.getBeginLine() + ").");
    }
    return hi;
}
Use of org.apache.sysml.hops.DataGenOp in the Apache SystemML project.
Class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndMinusOperation.
/**
 * Fuses the negation of a uniform rand data generator into the generator itself:
 * 0 - rand(min, max) -> rand(-max, -min).
 *
 * The rewrite is applied only if the operation is a MINUS with a literal 0 on the
 * left, the rand operator has exactly one parent, its min and max are literals,
 * and its pdf is the uniform literal. The rand operator is modified in place and
 * takes the position of the binary operation in all parents.
 *
 * @param hi high-level operator to inspect
 * @return the modified data generator if the rewrite applied, otherwise the given hop
 */
private static Hop fuseDatagenAndMinusOperation(Hop hi) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        // The min/max exchange below is only equivalent to 0-rand; without the operator
        // check any other binary operation with a literal 0 on the left would be
        // replaced by the negated rand as well.
        if (HopRewriteUtils.isBinary(bop, OpOp2.MINUS)
                && right instanceof DataGenOp && ((DataGenOp) right).getOp() == DataGenMethod.RAND
                && left instanceof LiteralOp && ((LiteralOp) left).getDoubleValue() == 0.0) {
            DataGenOp inputGen = (DataGenOp) right;
            HashMap<String, Integer> params = inputGen.getParamIndexMap();
            Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
            int ixMin = params.get(DataExpression.RAND_MIN);
            int ixMax = params.get(DataExpression.RAND_MAX);
            Hop min = right.getInput().get(ixMin);
            Hop max = right.getInput().get(ixMax);
            // apply rewrite under additional conditions (for simplicity)
            if (inputGen.getParent().size() == 1 && min instanceof LiteralOp && max instanceof LiteralOp
                    && pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue())) {
                // exchange and *-1 (special case 0 stays 0 instead of -0 for consistency)
                double newMinVal = (((LiteralOp) max).getDoubleValue() == 0) ? 0 : (-1 * ((LiteralOp) max).getDoubleValue());
                double newMaxVal = (((LiteralOp) min).getDoubleValue() == 0) ? 0 : (-1 * ((LiteralOp) min).getDoubleValue());
                Hop newMin = new LiteralOp(newMinVal);
                Hop newMax = new LiteralOp(newMaxVal);
                // swap the bound inputs position by position so the parameter indexes stay valid
                HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin);
                HopRewriteUtils.addChildReference(inputGen, newMin, ixMin);
                HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax);
                HopRewriteUtils.addChildReference(inputGen, newMax, ixMax);
                // rewire all parents (avoid anomalies with replicated datagen)
                List<Hop> parents = new ArrayList<>(bop.getParent());
                for (Hop p : parents) {
                    HopRewriteUtils.replaceChildReference(p, bop, inputGen);
                }
                hi = inputGen;
                LOG.debug("Applied fuseDatagenAndMinusOperation (line " + bop.getBeginLine() + ").");
            }
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.DataGenOp in the Apache SystemML project.
Class RewriteAlgebraicSimplificationStatic, method simplifyOrderedSort.
/**
 * Removes a sort over a sequence with increment 1 when the sort's decreasing and
 * indexreturn arguments are literals:
 *   order(seq(2,N+1,1), indexreturn=TRUE)  -> seq(1,N,1)/seq(N,1,-1)
 *   order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1)   (ascending only)
 *
 * @param parent parent of the inspected hop
 * @param hi high-level operator to inspect
 * @param pos position of hi in the parent's inputs
 * @return the replacement hop if a pattern applied, otherwise the given hop
 */
private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) {
    // only order (sort) operations qualify
    if (!(hi instanceof ReorgOp) || ((ReorgOp) hi).getOp() != ReOrgOp.SORT) {
        return hi;
    }

    // the sorted data must come directly from a seq data generator
    Hop sortedInput = hi.getInput().get(0);
    if (!(sortedInput instanceof DataGenOp) || ((DataGenOp) sortedInput).getOp() != DataGenMethod.SEQ) {
        return hi;
    }

    Hop increment = sortedInput.getInput().get(((DataGenOp) sortedInput).getParamIndex(Statement.SEQ_INCR));

    // check for known ascending ordering and known decreasing/indexreturn arguments
    boolean applicable = increment instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) increment) == 1
            && hi.getInput().get(2) instanceof LiteralOp   // decreasing
            && hi.getInput().get(3) instanceof LiteralOp;  // indexreturn
    if (!applicable) {
        return hi;
    }

    LiteralOp decreasingLit = (LiteralOp) hi.getInput().get(2);
    LiteralOp indexReturnLit = (LiteralOp) hi.getInput().get(3);

    if (HopRewriteUtils.getBooleanValue(indexReturnLit)) {
        // IXRET, ASC/DESC: order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1)
        boolean descending = HopRewriteUtils.getBooleanValue(decreasingLit);
        Hop indexSeq = HopRewriteUtils.createSeqDataGenOp(sortedInput, !descending);
        indexSeq.refreshSizeInformation();
        HopRewriteUtils.replaceChildReference(parent, hi, indexSeq, pos);
        HopRewriteUtils.cleanupUnreferenced(hi);
        hi = indexSeq;
        LOG.debug("Applied simplifyOrderedSort1.");
    } else if (!HopRewriteUtils.getBooleanValue(decreasingLit)) {
        // DATA, ASC: order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1)
        HopRewriteUtils.replaceChildReference(parent, hi, sortedInput, pos);
        HopRewriteUtils.cleanupUnreferenced(hi);
        hi = sortedInput;
        LOG.debug("Applied simplifyOrderedSort2.");
    }
    return hi;
}
Aggregations