Use of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationStatic, method simplifyOuterSeqExpand.
/**
 * Rewrites an outer "==" comparison against a basic 1:m sequence into a
 * {@code rexpand} parameterized builtin.
 * <ul>
 * <li>pattern a: {@code outer(v, t(seq(1,m)), "==")} -> rexpand with dir="cols"</li>
 * <li>pattern b: {@code outer(seq(1,m), t(v), "==")} -> rexpand with dir="rows"</li>
 * </ul>
 * If both patterns match, pattern b wins. In pattern b, a right input of the form
 * t(v) is unwrapped to v; any other right input is wrapped in a new transpose.
 * The rexpand's max comes from {@code getBasic1NSequenceMax} of the matched
 * sequence, and it is created with ignore=TRUE and cast=FALSE.
 *
 * @param parent parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate hop
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the new rexpand hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) {
    // only outer-vector equality comparisons qualify
    if (!HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) || !((BinaryOp) hi).isOuterVectorOperator())
        return hi;

    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);
    boolean rhsIsTranspose = HopRewriteUtils.isTransposeOperation(rhs);
    boolean patternA = rhsIsTranspose && HopRewriteUtils.isBasic1NSequence(rhs.getInput().get(0));
    boolean patternB = HopRewriteUtils.isBasic1NSequence(lhs);
    if (!patternA && !patternB)
        return hi;

    // resolve the expanded vector and the sequence for the matched pattern
    Hop target;
    Hop sequence;
    if (patternB) {
        // v is given either as t(v) (unwrap it) or as a row vector (transpose it)
        target = rhsIsTranspose ? rhs.getInput().get(0) : HopRewriteUtils.createTranspose(rhs);
        sequence = lhs;
    } else {
        target = lhs;
        sequence = rhs.getInput().get(0);
    }
    String direction = patternB ? "rows" : "cols";

    // parameters of the replacement rexpand
    HashMap<String, Hop> params = new HashMap<>();
    params.put("target", target);
    params.put("max", HopRewriteUtils.getBasic1NSequenceMax(sequence));
    params.put("dir", new LiteralOp(direction));
    params.put("ignore", new LiteralOp(true));
    params.put("cast", new LiteralOp(false));
    ParameterizedBuiltinOp rexpand = HopRewriteUtils.createParameterizedBuiltinOp(target, params, ParamBuiltinOp.REXPAND);

    // swap the new hop into the position formerly held by hi
    HopRewriteUtils.replaceChildReference(parent, hi, rexpand, pos);
    LOG.debug("Applied simplifyOuterSeqExpand (line " + rexpand.getBeginLine() + ")");
    return rexpand;
}
Use of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationStatic, method simplifySlicedMatrixMult.
/**
 * Pushes a single-cell right indexing below a matrix multiplication,
 * e.g. {@code (X%*%Y)[1,1] -> X[1,] %*% Y[,1]}, so only one row of X and
 * one column of Y enter the multiplication.
 * <p>
 * Applies only if the indexing selects a single row and column (rl==ru, cl==cu)
 * and is the only parent of the matrix multiplication. The multiplication hop is
 * reused: its inputs are detached and replaced by the two new indexing operations.
 * <p>
 * NOTE(review): {@code parent} and {@code pos} are not used, and the original
 * indexing hop is not unlinked, so it remains on top of the now 1x1 product.
 * Presumably a later rewrite removes that redundant indexing. Confirm this before
 * relying on the method in isolation.
 *
 * @param parent parent hop of {@code hi} (unused)
 * @param hi     candidate indexing hop
 * @param pos    position of {@code hi} within the inputs of {@code parent} (unused)
 * @return the rewired matrix multiplication if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof IndexingOp))
        return hi;
    IndexingOp rix = (IndexingOp) hi;
    Hop product = rix.getInput().get(0);
    boolean singleCell = rix.isRowLowerEqualsUpper() && rix.isColLowerEqualsUpper();
    // the indexing must be the sole consumer of a matrix multiplication
    if (!singleCell || product.getParent().size() != 1 || !HopRewriteUtils.isMatrixMultiply(product))
        return hi;

    Hop lhs = product.getInput().get(0);
    Hop rhs = product.getInput().get(1);
    Hop rowIx = rix.getInput().get(1); // rl == ru
    Hop colIx = rix.getInput().get(3); // cl == cu

    HopRewriteUtils.removeAllChildReferences(product);

    // lhs[rowIx, 1:ncol(lhs)]
    IndexingOp rowSlice = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, lhs,
        rowIx, rowIx, new LiteralOp(1), HopRewriteUtils.createValueHop(lhs, false), true, false);
    rowSlice.setOutputBlocksizes(lhs.getRowsInBlock(), lhs.getColsInBlock());
    rowSlice.refreshSizeInformation();

    // rhs[1:nrow(rhs), colIx]
    IndexingOp colSlice = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, rhs,
        new LiteralOp(1), HopRewriteUtils.createValueHop(rhs, true), colIx, colIx, false, true);
    colSlice.setOutputBlocksizes(rhs.getRowsInBlock(), rhs.getColsInBlock());
    colSlice.refreshSizeInformation();

    // multiply the two slices instead of the full operands
    HopRewriteUtils.addChildReference(product, rowSlice, 0);
    HopRewriteUtils.addChildReference(product, colSlice, 1);
    product.refreshSizeInformation();

    LOG.debug("Applied simplifySlicedMatrixMult");
    return product;
}
Use of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndBinaryOperation.
/**
 * Handle removal of unnecessary binary operations over rand data by fusing a
 * scalar operand into a rand data generator that has exactly one consumer.
 * <p>
 * Supported patterns (s a scalar literal, v any scalar):
 * <ul>
 * <li>rand*s -> rand(min*s,max*s); rand/s -> rand(min/s,max/s)</li>
 * <li>rand+s -> rand(min+s,max+s); rand-s -> rand(min+(-s),max+(-s))</li>
 * <li>s*rand -> rand(min*s,max*s); s+rand -> rand(min+s,max+s)</li>
 * <li>rand(min=0,max=0)+v -> rand(min=v,max=v)</li>
 * <li>rand(min=0,max=1)*v -> rand(min=0,max=v); rand(min=1,max=1)*v -> rand(min=v,max=v)</li>
 * </ul>
 * All patterns require a literal "uniform" pdf. The rewrite only adjusts min/max,
 * which are assumed to have no such meaning for other pdfs.
 * <p>
 * Shift-based patterns (+, -) also require a literal sparsity of 1. In a sparse
 * rand, the unfused expression turns the zero cells into the shift value, while a
 * fused rand would keep them zero. (For rand(min=0,max=0) the unfused zero matrix
 * plus v is dense regardless of the declared sparsity.) Scale-based patterns (*, /)
 * keep zeros at zero, so they are fine for sparse rand.
 * <p>
 * On success, every parent of the binary operation is rewired to the fused data
 * generator. If no fused generator is produced, nothing is rewired.
 *
 * @param hi high-order operation
 * @return the fused data generator if a rewrite applied, otherwise {@code hi}
 */
@SuppressWarnings("incomplete-switch")
private static Hop fuseDatagenAndBinaryOperation(Hop hi) {
    if (!(hi instanceof BinaryOp))
        return hi;
    BinaryOp bop = (BinaryOp) hi;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    DataGenOp gen = null;  // fused data generator, if any
    String variant = null; // rewrite variant reported in the debug message

    if (HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right instanceof LiteralOp && left.getParent().size() == 1) {
        // left input rand and hence output matrix double, right scalar literal
        DataGenOp inputGen = (DataGenOp) left;
        Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
        Hop min = inputGen.getInput(DataExpression.RAND_MIN);
        Hop max = inputGen.getInput(DataExpression.RAND_MAX);
        Hop sparsity = inputGen.getInput(DataExpression.RAND_SPARSITY);
        double sval = ((LiteralOp) right).getDoubleValue();
        boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
        boolean dense = HopRewriteUtils.isLiteralOfValue(sparsity, 1);
        boolean scaleOp = HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.DIV);
        boolean shiftOp = HopRewriteUtils.isBinary(bop, OpOp2.PLUS, OpOp2.MINUS);
        if ((scaleOp || (shiftOp && dense)) && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform) {
            // fuse via scale and shift
            switch (bop.getOp()) {
                case MULT:
                    gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0);
                    break;
                case PLUS:
                case MINUS:
                    gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval * ((bop.getOp() == OpOp2.MINUS) ? -1 : 1));
                    break;
                case DIV:
                    gen = HopRewriteUtils.copyDataGenOp(inputGen, 1 / sval, 0);
                    break;
            }
            variant = "1";
        }
    } else if (HopRewriteUtils.isDataGenOp(right, DataGenMethod.RAND) && left instanceof LiteralOp && right.getParent().size() == 1) {
        // right input rand and hence output matrix double, left scalar literal
        DataGenOp inputGen = (DataGenOp) right;
        Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
        Hop min = inputGen.getInput(DataExpression.RAND_MIN);
        Hop max = inputGen.getInput(DataExpression.RAND_MAX);
        Hop sparsity = inputGen.getInput(DataExpression.RAND_SPARSITY);
        double sval = ((LiteralOp) left).getDoubleValue();
        boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
        boolean dense = HopRewriteUtils.isLiteralOfValue(sparsity, 1);
        if (min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform) {
            if (HopRewriteUtils.isBinary(bop, OpOp2.MULT)) {
                gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0);
                variant = "2";
            } else if (HopRewriteUtils.isBinary(bop, OpOp2.PLUS) && dense) {
                gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval);
                variant = "2";
            }
        }
    } else if (HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right.getDataType().isScalar() && left.getParent().size() == 1) {
        // left input rand and hence output matrix double, right scalar variable
        DataGenOp inputGen = (DataGenOp) left;
        Hop min = inputGen.getInput(DataExpression.RAND_MIN);
        Hop max = inputGen.getInput(DataExpression.RAND_MAX);
        Hop pdf = inputGen.getInput(DataExpression.RAND_PDF);
        Hop sparsity = inputGen.getInput(DataExpression.RAND_SPARSITY);
        boolean pdfUniform = pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue());
        boolean dense = HopRewriteUtils.isLiteralOfValue(sparsity, 1);
        if (HopRewriteUtils.isBinary(bop, OpOp2.PLUS) && HopRewriteUtils.isLiteralOfValue(min, 0)
            && HopRewriteUtils.isLiteralOfValue(max, 0) && pdfUniform && dense) {
            // all-zero matrix plus v -> constant v
            inputGen.setInput(DataExpression.RAND_MIN, right);
            inputGen.setInput(DataExpression.RAND_MAX, right);
            gen = inputGen;
            variant = "3a";
        } else if (HopRewriteUtils.isBinary(bop, OpOp2.MULT) && pdfUniform
            && (HopRewriteUtils.isLiteralOfValue(min, 0) || HopRewriteUtils.isLiteralOfValue(min, 1))
            && HopRewriteUtils.isLiteralOfValue(max, 1)) {
            // uniform [0,1] * v -> uniform [0,v]; constant 1 * v -> constant v
            if (HopRewriteUtils.isLiteralOfValue(min, 1))
                inputGen.setInput(DataExpression.RAND_MIN, right);
            inputGen.setInput(DataExpression.RAND_MAX, right);
            gen = inputGen;
            variant = "3b";
        }
    }

    if (gen != null) {
        // rewire all parents (avoid anomalies with replicated datagen)
        List<Hop> parents = new ArrayList<>(bop.getParent());
        for (Hop p : parents)
            HopRewriteUtils.replaceChildReference(p, bop, gen);
        hi = gen;
        LOG.debug("Applied fuseDatagenAndBinaryOperation" + variant + " (" + bop.getFilename() + ", line " + bop.getBeginLine() + ").");
    }
    return hi;
}
Use of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationStatic, method fuseLogNzBinaryOperation.
/**
 * Fuses a masked binary logarithm into a single log_nz operation:
 * {@code (X!=0) * log(X, b) -> log_nz(X, b)}.
 * <p>
 * The original note says the fusion helps the memory estimate and prevents dense
 * intermediates if X is ultra sparse. Only this operand order is matched. The mask
 * must compare the very same hop X (identity check) against a literal 0, so the
 * rewrite depends on common subexpression elimination having merged duplicates of X.
 *
 * @param parent parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate hop
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the new log_nz hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!HopRewriteUtils.isBinary(hi, OpOp2.MULT))
        return hi;
    Hop mask = hi.getInput().get(0);
    Hop logHop = hi.getInput().get(1);
    boolean bothMatrices = mask.getDataType() == DataType.MATRIX && logHop.getDataType() == DataType.MATRIX;
    if (!bothMatrices || !HopRewriteUtils.isBinary(logHop, OpOp2.LOG))
        return hi;

    Hop X = logHop.getInput().get(0);
    Hop base = logHop.getInput().get(1);

    // mask must be (X != 0) over the identical hop X
    boolean nonZeroMask = HopRewriteUtils.isBinary(mask, OpOp2.NOTEQUAL)
        && mask.getInput().get(0) == X
        && mask.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) mask.getInput().get(1)) == 0;
    if (!nonZeroMask)
        return hi;

    Hop logNz = HopRewriteUtils.createBinary(X, base, OpOp2.LOG_NZ);
    // swap the fused hop into the position formerly held by hi
    HopRewriteUtils.replaceChildReference(parent, hi, logNz, pos);
    LOG.debug("Applied fuseLogNzBinaryOperation (line " + logNz.getBeginLine() + ")");
    return logNz;
}
Use of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationStatic, method simplifyOrderedSort.
/**
 * Removes a sort over a sequence with literal increment 1, which is already in
 * ascending order.
 * <ul>
 * <li>index return: {@code order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)},
 * or {@code seq(N,1,-1)} when decreasing=TRUE</li>
 * <li>data, ascending: {@code order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1)}</li>
 * </ul>
 * Both the decreasing flag (input 2) and the index-return flag (input 3) must be
 * literals. A data sort in decreasing order is left unchanged.
 *
 * @param parent parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate hop
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the replacement hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) {
    // only order(...) qualifies
    if (!(hi instanceof ReorgOp) || ((ReorgOp) hi).getOp() != ReOrgOp.SORT)
        return hi;
    Hop sortInput = hi.getInput().get(0);
    if (!(sortInput instanceof DataGenOp) || ((DataGenOp) sortInput).getOp() != DataGenMethod.SEQ)
        return hi;

    DataGenOp seqGen = (DataGenOp) sortInput;
    Hop increment = seqGen.getInput().get(seqGen.getParamIndex(Statement.SEQ_INCR));
    // known ascending ordering
    if (!(increment instanceof LiteralOp) || HopRewriteUtils.getDoubleValue((LiteralOp) increment) != 1)
        return hi;

    Hop decreasingFlag = hi.getInput().get(2);
    Hop indexReturnFlag = hi.getInput().get(3);
    if (!(decreasingFlag instanceof LiteralOp) || !(indexReturnFlag instanceof LiteralOp))
        return hi;

    boolean indexReturn = HopRewriteUtils.getBooleanValue((LiteralOp) indexReturnFlag);
    boolean decreasing = HopRewriteUtils.getBooleanValue((LiteralOp) decreasingFlag);

    if (indexReturn) {
        // IXRET, ASC/DESC: the sort permutation is an identity or reversed index sequence
        Hop indexSeq = HopRewriteUtils.createSeqDataGenOp(sortInput, !decreasing);
        indexSeq.refreshSizeInformation();
        HopRewriteUtils.replaceChildReference(parent, hi, indexSeq, pos);
        HopRewriteUtils.cleanupUnreferenced(hi);
        LOG.debug("Applied simplifyOrderedSort1.");
        return indexSeq;
    }
    if (!decreasing) {
        // DATA, ASC: sorting the ascending sequence is a no-op
        HopRewriteUtils.replaceChildReference(parent, hi, sortInput, pos);
        HopRewriteUtils.cleanupUnreferenced(hi);
        LOG.debug("Applied simplifyOrderedSort2.");
        return sortInput;
    }
    return hi;
}
Aggregations