Example use of org.apache.sysml.hops.UnaryOp in the project systemml by apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedUnaryMM.
/**
 * Tries to replace the expression rooted at {@code hi} by a single
 * {@code QuaternaryOp} of type {@code OpOp4.WUMM}.
 *
 * <p>Three shapes are recognized, checked in this order; at most one is applied:
 * <ul>
 *   <li>Pattern 1: {@code W * uop(U %*% t(V))} for unary ops in
 *       {@code LOOKUP_VALID_WUMM_UNARY}</li>
 *   <li>Pattern 2.7: {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))}</li>
 *   <li>Pattern 2: {@code W * sop(U %*% t(V), 2)} (or {@code 2 * (U %*% t(V))}) for
 *       binary ops in {@code LOOKUP_VALID_WUMM_BINARY}</li>
 * </ul>
 *
 * @param parent hop whose child reference to {@code hi} is replaced on a match
 * @param hi     candidate root of the expression
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the newly created quaternary hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
    // every pattern is rooted at a binary operation; anything else is returned untouched
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }

    final OpOp2 rootOp = ((BinaryOp) hi).getOp();
    Hop fused = null; // stays null until one of the patterns fires

    // precondition shared by patterns 1 and 2 (side-effect free, so evaluated once)
    final boolean weightedCandidate =
        HopRewriteUtils.isValidOp(rootOp, LOOKUP_VALID_WDIVMM_BINARY)
        // prevent mv
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1))
        // not applied for vector-vector mult
        && hi.getDim2() > 1
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock();

    // Pattern 1) (W*uop(U%*%t(V)))
    if (weightedCandidate && hi.getInput().get(1) instanceof UnaryOp) {
        final UnaryOp unary = (UnaryOp) hi.getInput().get(1);
        if (HopRewriteUtils.isValidOp(unary.getOp(), LOOKUP_VALID_WUMM_UNARY)
            && unary.getInput().get(0) instanceof AggBinaryOp
            // BLOCKSIZE CONSTRAINT
            && HopRewriteUtils.isSingleBlock(unary.getInput().get(0).getInput().get(0), true)) {
            final Hop product = unary.getInput().get(0);
            final Hop W = hi.getInput().get(0);
            final Hop U = product.getInput().get(0);
            final Hop rhs = product.getInput().get(1);
            final boolean mult = rootOp == OpOp2.MULT;
            final OpOp1 uop = unary.getOp();
            // hand over V itself: strip an existing transpose, otherwise add one
            final Hop V = HopRewriteUtils.isTransposeOperation(rhs)
                ? rhs.getInput().get(0)
                : HopRewriteUtils.createTranspose(rhs);

            fused = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, uop, null);
            fused.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            fused.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
        }
    }

    // Pattern 2.7) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
    if (fused == null
        && HopRewriteUtils.isValidOp(rootOp, OpOp2.MULT)
        && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
            || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) {
        // non-literal operand of the outer multiplication
        final Hop nl = (hi.getInput().get(0) instanceof LiteralOp)
            ? hi.getInput().get(1)
            : hi.getInput().get(0);

        if (HopRewriteUtils.isBinary(nl, OpOp2.MULT)
            // ensure no foreign parents
            && nl.getParent().size() == 1) {
            final Hop weights = nl.getInput().get(0);
            final Hop product = nl.getInput().get(1);
            if (// prevent mv
                HopRewriteUtils.isEqualSize(weights, product)
                // not applied for vector-vector mult
                && nl.getDim2() > 1
                && weights.getDataType() == DataType.MATRIX
                && weights.getDim2() > weights.getColsInBlock()
                && HopRewriteUtils.isOuterProductLikeMM(product)
                // no mmchain
                && (((AggBinaryOp) product).checkMapMultChain() == ChainType.NONE
                    || product.getInput().get(1).getDim2() > 1)
                && HopRewriteUtils.isSingleBlock(product.getInput().get(0), true)) {
                final Hop W = weights;
                final Hop U = product.getInput().get(0);
                final Hop rhs = product.getInput().get(1);
                final Hop V = HopRewriteUtils.isTransposeOperation(rhs)
                    ? rhs.getInput().get(0)
                    : HopRewriteUtils.createTranspose(rhs);

                fused = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
                fused.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                fused.refreshSizeInformation();
                LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
            }
        }
    }

    // Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (fused == null
        && weightedCandidate
        && hi.getInput().get(1) instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        final BinaryOp scalarOp = (BinaryOp) hi.getInput().get(1);
        final Hop lhs = scalarOp.getInput().get(0);
        final Hop rhs = scalarOp.getInput().get(1);

        Hop product = null;
        if (rhs.getDataType() == DataType.SCALAR
            && rhs instanceof LiteralOp
            // pow2, mult2
            && HopRewriteUtils.getDoubleValue((LiteralOp) rhs) == 2
            && lhs instanceof AggBinaryOp
            // BLOCKSIZE CONSTRAINT
            && HopRewriteUtils.isSingleBlock(lhs.getInput().get(0), true)) {
            // pattern 2a) matrix-scalar operations
            product = lhs;
        } else if (lhs.getDataType() == DataType.SCALAR
            && lhs instanceof LiteralOp
            // mult2
            && HopRewriteUtils.getDoubleValue((LiteralOp) lhs) == 2
            && scalarOp.getOp() == OpOp2.MULT
            && rhs instanceof AggBinaryOp
            // BLOCKSIZE CONSTRAINT
            && HopRewriteUtils.isSingleBlock(rhs.getInput().get(0), true)) {
            // pattern 2b) scalar-matrix operations
            product = rhs;
        }

        if (product != null) {
            final Hop W = hi.getInput().get(0);
            final Hop U = product.getInput().get(0);
            final Hop vIn = product.getInput().get(1);
            final boolean mult = rootOp == OpOp2.MULT;
            final OpOp2 sop = scalarOp.getOp();
            final Hop V = HopRewriteUtils.isTransposeOperation(vIn)
                ? vIn.getInput().get(0)
                : HopRewriteUtils.createTranspose(vIn);

            fused = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, null, sop);
            fused.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            fused.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }

    // no pattern matched: leave the DAG as it is
    if (fused == null) {
        return hi;
    }

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}
Example use of org.apache.sysml.hops.UnaryOp in the project systemml by apache.
Class HopRewriteUtils, method rContainsRead.
/**
 * Recursively checks whether the hops reachable from {@code root} via
 * {@code getInput()} contain a read {@code DataOp} whose name equals {@code var}.
 *
 * <p>Hops already flagged as visited are skipped and reported as {@code false};
 * every hop processed here is flagged as visited before returning.
 *
 * @param root          hop at which the search starts
 * @param var           variable name to look for
 * @param includeMetaOp if {@code true}, every matching read counts; if {@code false},
 *                      a matching read counts only when at least one of its parents
 *                      is something other than an {@code NROW}/{@code NCOL} unary
 *                      operation (a matching read without parents does not count)
 * @return {@code true} if a qualifying read was found at or below {@code root}
 */
public static boolean rContainsRead(Hop root, String var, boolean includeMetaOp) {
    if (root.isVisited()) {
        return false;
    }

    boolean found = false;

    // handle leaf node for variable
    final boolean readsVar = root instanceof DataOp
        && ((DataOp) root).isRead()
        && root.getName().equals(var);
    if (readsVar) {
        if (includeMetaOp) {
            found = true;
        } else {
            // qualifies as soon as any consumer is not a pure nrow/ncol lookup
            for (Hop consumer : root.getParent()) {
                final boolean metaOnly = consumer instanceof UnaryOp
                    && (((UnaryOp) consumer).getOp() == OpOp1.NROW
                        || ((UnaryOp) consumer).getOp() == OpOp1.NCOL);
                found |= !metaOnly;
            }
        }
    }

    // recursively process childs; every child is descended into, even after a hit,
    // so that the visited flags are set on the whole sub-DAG
    for (Hop child : root.getInput()) {
        final boolean inChild = rContainsRead(child, var, includeMetaOp);
        found = found || inChild;
    }

    root.setVisited();
    return found;
}
Aggregations