use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class DMLTranslator method processRelationalExpression.
/**
 * Translates a relational expression into a {@link BinaryOp} hop.
 *
 * Both operands are translated first. If no target is supplied, one is
 * created from the source expression and typed as a DOUBLE matrix when at
 * least one operand is a matrix, or as a BOOLEAN scalar otherwise.
 *
 * @param source relational expression to translate
 * @param target output identifier, or null to derive one from the source
 * @param hops   map of already constructed hops, passed to operand translation
 * @return the binary hop for the comparison; its operator is null if the
 *         expression's opcode is none of the six handled relational operators
 */
private Hop processRelationalExpression(RelationalExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
    final Hop lhs = processExpression(source.getLeft(), null, hops);
    final Hop rhs = processExpression(source.getRight(), null, hops);

    if (target == null) {
        target = createTarget(source);
        // matrix comparison yields a matrix (only value type double is supported),
        // scalar comparison yields a boolean scalar
        final boolean matrixResult = lhs.getDataType() == DataType.MATRIX || rhs.getDataType() == DataType.MATRIX;
        target.setDataType(matrixResult ? DataType.MATRIX : DataType.SCALAR);
        target.setValueType(matrixResult ? ValueType.DOUBLE : ValueType.BOOLEAN);
    }

    // map the language-level operator onto the hop-level operator
    final Expression.RelationalOp code = source.getOpCode();
    OpOp2 hopOp = null;
    if (code == Expression.RelationalOp.LESS) {
        hopOp = OpOp2.LESS;
    } else if (code == Expression.RelationalOp.LESSEQUAL) {
        hopOp = OpOp2.LESSEQUAL;
    } else if (code == Expression.RelationalOp.GREATER) {
        hopOp = OpOp2.GREATER;
    } else if (code == Expression.RelationalOp.GREATEREQUAL) {
        hopOp = OpOp2.GREATEREQUAL;
    } else if (code == Expression.RelationalOp.EQUAL) {
        hopOp = OpOp2.EQUAL;
    } else if (code == Expression.RelationalOp.NOTEQUAL) {
        hopOp = OpOp2.NOTEQUAL;
    }

    final Hop result = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), hopOp, lhs, rhs);
    result.setParseInfo(source);
    return result;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method removeUnecessaryPPred.
/**
 * Replaces a relational comparison of an operand with itself by a constant
 * data-generation hop: 1 for ==, >=, <= and 0 for !=, >, <.
 *
 * NOTE: currently disabled since this rewrite is INVALID in the
 * presence of NaNs (because (NaN!=NaN) is true).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
@SuppressWarnings("unused")
private static Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);

        // The result is only statically known when both inputs are the very
        // same hop. The identity check must guard ALL operators: without the
        // explicit grouping, '&&' binds tighter than '||' and e.g. any '>='
        // comparison would be replaced regardless of its operands.
        if (left == right) {
            OpOp2 op = bop.getOp();
            Hop datagen = null;

            // ppred(X,X,"==") -> matrix(1, rows=nrow(X),cols=ncol(X))
            if (op == OpOp2.EQUAL || op == OpOp2.GREATEREQUAL || op == OpOp2.LESSEQUAL)
                datagen = HopRewriteUtils.createDataGenOp(left, 1);
            // ppred(X,X,"!=") -> matrix(0, rows=nrow(X),cols=ncol(X))
            else if (op == OpOp2.NOTEQUAL || op == OpOp2.GREATER || op == OpOp2.LESS)
                datagen = HopRewriteUtils.createDataGenOp(left, 0);

            if (datagen != null) {
                HopRewriteUtils.replaceChildReference(parent, hi, datagen, pos);
                hi = datagen;
            }
        }
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method foldMultipleAppendOperations.
/**
 * Folds nested cbind/rbind operations of the same kind into a single n-ary
 * cbind/rbind, repeating until no further nested operation can be merged.
 *
 * Only applies to matrix outputs (no string appends or frames) and only
 * outside of Hadoop execution mode. A nested append is merged only if it
 * has exactly one parent.
 *
 * @param hi high-level operator
 * @return the (possibly replaced) high-level operator
 */
private static Hop foldMultipleAppendOperations(Hop hi) {
    // no string appends or frames
    final boolean applicable = hi.getDataType().isMatrix()
        && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND, OpOp2.RBIND) || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND))
        && !OptimizerUtils.isHadoopExecutionMode();
    if (!applicable) {
        return hi;
    }

    // derive both the binary and the n-ary flavor of the current append type
    final OpOp2 binType = (hi instanceof BinaryOp) ? ((BinaryOp) hi).getOp() : OpOp2.valueOf(((NaryOp) hi).getOp().name());
    final OpOpN naryType = (hi instanceof NaryOp) ? ((NaryOp) hi).getOp() : OpOpN.valueOf(((BinaryOp) hi).getOp().name());

    while (true) {
        // get first matching cbind or rbind
        Hop nested = null;
        for (Hop candidate : hi.getInput()) {
            if (HopRewriteUtils.isBinary(candidate, binType) || HopRewriteUtils.isNary(candidate, naryType)) {
                nested = candidate;
                break;
            }
        }

        // stop once there is nothing left to fold or the nested op is shared
        if (nested == null || nested.getParent().size() != 1) {
            break;
        }

        // construct new list of inputs (in original order)
        ArrayList<Hop> flattened = new ArrayList<>();
        for (Hop in : hi.getInput()) {
            if (in == nested) {
                flattened.addAll(nested.getInput());
            } else {
                flattened.add(in);
            }
        }
        Hop folded = HopRewriteUtils.createNary(naryType, flattened.toArray(new Hop[0]));

        // clear dangling references
        HopRewriteUtils.removeAllChildReferences(hi);
        HopRewriteUtils.removeAllChildReferences(nested);

        // rewire all parents (avoid anomalies with refs to hi)
        List<Hop> oldParents = new ArrayList<>(hi.getParent());
        for (Hop p : oldParents) {
            HopRewriteUtils.replaceChildReference(p, hi, folded);
        }

        hi = folded;
        LOG.debug("Applied foldMultipleAppendOperations (line " + hi.getBeginLine() + ").");
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method fuseDatagenAndMinusOperation.
/**
 * Fuses a negation of a uniform random matrix into the data generator:
 * {@code 0 - rand(min, max)} becomes {@code rand(-max, -min)}.
 *
 * The rewrite applies only if the binary operation is a MINUS with literal
 * 0 on the left and a RAND data generator on the right, the generator has
 * a single parent, and its min, max and pdf inputs are literals with a
 * uniform pdf.
 *
 * @param hi high-level operator
 * @return the data generator if the rewrite was applied, otherwise hi
 */
private static Hop fuseDatagenAndMinusOperation(Hop hi) {
    if (hi instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        // The operator must be MINUS: for any other binary operator with a
        // zero left operand (e.g. 0*rand or 0/rand), negating the generator
        // bounds would produce a different result.
        if (bop.getOp() == OpOp2.MINUS && right instanceof DataGenOp && ((DataGenOp) right).getOp() == DataGenMethod.RAND && left instanceof LiteralOp && ((LiteralOp) left).getDoubleValue() == 0.0) {
            DataGenOp inputGen = (DataGenOp) right;
            HashMap<String, Integer> params = inputGen.getParamIndexMap();
            Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
            int ixMin = params.get(DataExpression.RAND_MIN);
            int ixMax = params.get(DataExpression.RAND_MAX);
            Hop min = right.getInput().get(ixMin);
            Hop max = right.getInput().get(ixMax);
            // apply rewrite under additional conditions (for simplicity)
            if (inputGen.getParent().size() == 1 && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue())) {
                // exchange and *-1 (special case 0 stays 0 instead of -0 for consistency)
                double newMinVal = (((LiteralOp) max).getDoubleValue() == 0) ? 0 : (-1 * ((LiteralOp) max).getDoubleValue());
                double newMaxVal = (((LiteralOp) min).getDoubleValue() == 0) ? 0 : (-1 * ((LiteralOp) min).getDoubleValue());
                Hop newMin = new LiteralOp(newMinVal);
                Hop newMax = new LiteralOp(newMaxVal);
                // swap the bounds in place, keeping their input positions
                HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin);
                HopRewriteUtils.addChildReference(inputGen, newMin, ixMin);
                HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax);
                HopRewriteUtils.addChildReference(inputGen, newMax, ixMax);
                // rewire all parents (avoid anomalies with replicated datagen)
                List<Hop> parents = new ArrayList<>(bop.getParent());
                for (Hop p : parents) {
                    HopRewriteUtils.replaceChildReference(p, bop, inputGen);
                }
                hi = inputGen;
                LOG.debug("Applied fuseDatagenAndMinusOperation (line " + bop.getBeginLine() + ").");
            }
        }
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteElementwiseMultChainOptimization method constructReplacement.
/**
 * Builds the replacement element-wise multiplication tree from the leaves
 * of an emult chain.
 *
 * Each leaf is detached from its parents in {@code emults}, wrapped via
 * {@code constructPower} with its associated count, and ordered by
 * {@code compareByDataType}. Consecutive runs of the ordered leaves are
 * multiplied into four groups (scalars/column vectors, row vectors,
 * matrices, everything else), which are finally combined as
 * {@code ((other * matrices) * rowVectors) * colVectorsScalars}.
 * Every newly created multiplication hop is marked as visited.
 *
 * @param emults binary multiplication hops that are being replaced
 * @param leaves leaf hops of the chain mapped to the count passed to constructPower
 * @return root of the new multiplication tree, or null if there are no leaves
 */
private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
    // Sort by data type
    final SortedSet<Hop> ordered = new TreeSet<>(compareByDataType);
    for (final Map.Entry<Hop, Integer> leafEntry : leaves.entrySet()) {
        final Hop leaf = leafEntry.getKey();
        // unlink parents that are in the emult set (we are throwing them away),
        // keep other parents
        leaf.getParent().removeIf(p -> p instanceof BinaryOp && emults.contains(p));
        ordered.add(constructPower(leaf, leafEntry.getValue()));
    }

    // walk the ordered leaves exactly once, consuming one group after another
    final Iterator<Hop> it = ordered.iterator();
    Hop cur = it.hasNext() ? it.next() : null;

    // group 1: scalars and column vectors (new element becomes the left operand)
    Hop colVectorsScalars = null;
    while (cur != null
        && (cur.getDataType() == Expression.DataType.SCALAR
            || (cur.getDataType() == Expression.DataType.MATRIX && cur.getDim2() == 1))) {
        if (colVectorsScalars != null) {
            colVectorsScalars = HopRewriteUtils.createBinary(cur, colVectorsScalars, Hop.OpOp2.MULT);
            colVectorsScalars.setVisited();
        } else {
            colVectorsScalars = cur;
        }
        cur = it.hasNext() ? it.next() : null;
    }

    // group 2: row vectors; cur is unprocessed and either null or past col vectors
    Hop rowVectors = null;
    while (cur != null && cur.getDataType() == Expression.DataType.MATRIX && cur.getDim1() == 1) {
        if (rowVectors != null) {
            rowVectors = HopRewriteUtils.createBinary(rowVectors, cur, Hop.OpOp2.MULT);
            rowVectors.setVisited();
        } else {
            rowVectors = cur;
        }
        cur = it.hasNext() ? it.next() : null;
    }

    // group 3: remaining matrices; cur is unprocessed and either null or past row vectors
    Hop matrices = null;
    while (cur != null && cur.getDataType() == Expression.DataType.MATRIX) {
        if (matrices != null) {
            matrices = HopRewriteUtils.createBinary(matrices, cur, Hop.OpOp2.MULT);
            matrices.setVisited();
        } else {
            matrices = cur;
        }
        cur = it.hasNext() ? it.next() : null;
    }

    // group 4: everything else; cur is unprocessed and either null or past matrices
    Hop other = null;
    while (cur != null) {
        if (other != null) {
            other = HopRewriteUtils.createBinary(other, cur, Hop.OpOp2.MULT);
            other.setVisited();
        } else {
            other = cur;
        }
        cur = it.hasNext() ? it.next() : null;
    }

    // finished: ((other * matrices) * rowVectors) * colVectorsScalars
    Hop top;
    if (other == null) {
        // may still be null if there are no matrices either
        top = matrices;
    } else if (matrices == null) {
        top = other;
    } else {
        top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT);
        top.setVisited();
    }

    if (rowVectors != null) {
        if (top == null) {
            top = rowVectors;
        } else {
            top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT);
            top.setVisited();
        }
    }

    if (colVectorsScalars != null) {
        if (top == null) {
            top = colVectorsScalars;
        } else {
            top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT);
            top.setVisited();
        }
    }
    return top;
}
Aggregations