Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedCrossEntropy.
/**
 * Fuses a weighted cross-entropy expression that is rooted by a full
 * (row and column) sum into a single WCEMM quaternary operator.
 *
 * Recognized shapes (tried in this order, at most one applies):
 *   1) sum(X * log(U %*% t(V)))        -> WCEMM(X, U, V, 0.0) with flag 0
 *   2) sum(X * log(U %*% t(V) + eps))  -> WCEMM(X, U, V, eps) with flag 1 (BASIC_EPS),
 *                                         eps being a scalar literal
 *
 * Both shapes require that the summed product has more than one column (no
 * vector-vector case), that X is a matrix of the same size as the log term (no
 * matrix-vector broadcasting), and that U passes
 * HopRewriteUtils.isSingleBlock(U, true) (block size constraint). If the right
 * factor of the matrix product is already a transpose t(V), V is taken directly;
 * otherwise a transpose of that factor is created.
 *
 * @param parent parent of {@code hi}; the fused hop is linked in at {@code pos}
 * @param hi     candidate root of the pattern
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the new WCEMM hop if a pattern matched, otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
    // The root must be sum(...) over rows and columns.
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getDirection() != Direction.RowCol || agg.getOp() != AggOp.SUM) {
        return hi;
    }
    // ...whose input is a binary operation with more than one column.
    Hop sumInput = hi.getInput().get(0);
    if (!(sumInput instanceof BinaryOp) || sumInput.getDim2() <= 1) {
        return hi;
    }

    BinaryOp product = (BinaryOp) sumInput;
    Hop weights = product.getInput().get(0);
    Hop logTerm = product.getInput().get(1);

    // Shared prefix of both patterns: X * log(...), with equally sized operands.
    boolean weightedLog = product.getOp() == OpOp2.MULT
            && weights.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isEqualSize(weights, logTerm)
            && HopRewriteUtils.isUnary(logTerm, OpOp1.LOG);
    if (!weightedLog) {
        return hi;
    }

    Hop logArg = logTerm.getInput().get(0);
    Hop hnew;
    if (logArg instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0), true)) {
        // Pattern 1) sum( X * log(U %*% t(V)))
        Hop u = logArg.getInput().get(0);
        Hop v = logArg.getInput().get(1);
        v = HopRewriteUtils.isTransposeOperation(v) ? v.getInput().get(0) : HopRewriteUtils.createTranspose(v);
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM,
                weights, u, v, new LiteralOp(0.0), 0, false, false);
        hnew.setOutputBlocksizes(weights.getRowsInBlock(), weights.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
    } else if (HopRewriteUtils.isBinary(logArg, OpOp2.PLUS)
            && logArg.getInput().get(0) instanceof AggBinaryOp
            && logArg.getInput().get(1) instanceof LiteralOp
            && logArg.getInput().get(1).getDataType() == DataType.SCALAR
            && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0).getInput().get(0), true)) {
        // Pattern 2) sum( X * log(U %*% t(V) + eps))
        Hop matMult = logArg.getInput().get(0);
        Hop eps = logArg.getInput().get(1);
        Hop u = matMult.getInput().get(0);
        Hop v = matMult.getInput().get(1);
        v = HopRewriteUtils.isTransposeOperation(v) ? v.getInput().get(0) : HopRewriteUtils.createTranspose(v);
        // flag 1 => BASIC_EPS
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM,
                weights, u, v, eps, 1, false, false);
        hnew.setOutputBlocksizes(weights.getRowsInBlock(), weights.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
    } else {
        return hi;
    }

    // Relink the fused hop into the original position.
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method removeUnnecessaryOuterProduct.
/**
 * Removes an outer product with a constant-ones data-gen operand when its only
 * purpose is to replicate a vector for a binary cell operation, relying on the
 * binary operation's own vector handling instead.
 *
 * Rewrites (checked in this order, at most one applies):
 *   1) A op (b %*% ones) -> A op b   if b has one column (column replication)
 *   2) A op (ones %*% b) -> A op b   if b has one row (row replication)
 *   3) (a %*% ones) op b -> outer(a, b, op)   if op is a valid outer binary op,
 *      the product is an outer product (left factor with one column or ones
 *      factor with one row), the product has more than one row, and b has one row
 *
 * @param parent parent of {@code hi}; only rewrite 3 relinks it
 * @param hi     candidate binary cell operation
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the new outer binary hop for rewrite 3, otherwise {@code hi}
 */
private static Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    OpOp2 bop = ((BinaryOp) hi).getOp();
    Hop left = hi.getInput().get(0);
    Hop right = hi.getInput().get(1);
    boolean rightIsMatMult = HopRewriteUtils.isMatrixMultiply(right);

    // 1) matrix-vector column replication: (A + b %*% ones) -> (A + b)
    if (rightIsMatMult
            && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
            && right.getInput().get(0).getDim2() == 1) {
        Hop columnVector = right.getInput().get(0);
        HopRewriteUtils.replaceChildReference(hi, right, columnVector, 1);
        HopRewriteUtils.cleanupUnreferenced(right);
        LOG.debug("Applied removeUnnecessaryOuterProduct1 (line " + right.getBeginLine() + ")");
        return hi;
    }

    // 2) matrix-vector row replication: (A + ones %*% b) -> (A + b)
    if (rightIsMatMult
            && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1)
            && right.getInput().get(1).getDim1() == 1) {
        Hop rowVector = right.getInput().get(1);
        HopRewriteUtils.replaceChildReference(hi, right, rowVector, 1);
        HopRewriteUtils.cleanupUnreferenced(right);
        LOG.debug("Applied removeUnnecessaryOuterProduct2 (line " + right.getBeginLine() + ")");
        return hi;
    }

    // 3) vector-vector column replication: ((a %*% ones) == b) -> outer(a, b, "==")
    boolean outerCandidate = HopRewriteUtils.isValidOuterBinaryOp(bop)
            && HopRewriteUtils.isMatrixMultiply(left)
            && HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
            && (left.getInput().get(0).getDim2() == 1 || left.getInput().get(1).getDim1() == 1)
            && left.getDim1() != 1
            && right.getDim1() == 1;
    if (!outerCandidate) {
        return hi;
    }
    Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied removeUnnecessaryOuterProduct3 (line " + right.getBeginLine() + ")");
    return hnew;
}
Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method simplifyEmptyBinaryOperation.
/**
 * Simplifies matrix-matrix binary cell operations (MULT, PLUS, MINUS) where one
 * or both operands are empty according to HopRewriteUtils.isEmpty.
 *
 * The right operand may be a vector that is broadcast, in which case the result
 * has the size of the left operand; the rewrites below account for that.
 *   X * Y : replaced by a zero data-gen sized like an empty left operand, like an
 *           empty right operand that has more than one row and column, or
 *           otherwise (empty, possibly-vector right) like the left operand.
 *   X + Y : both empty -> zero data-gen sized like left; empty left (only when not
 *           a matrix-vector op) -> right; empty right -> left.
 *   X - Y : empty left (only when not a matrix-vector op) -> left input replaced
 *           in place by literal 0, i.e. 0 - Y; empty right -> left.
 *
 * @param parent parent of {@code hi}; receives the replacement at {@code pos}
 * @param hi     candidate binary operation
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the replacement hop if a rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptyBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    Hop left = hi.getInput().get(0);
    Hop right = hi.getInput().get(1);
    if (left.getDataType() != DataType.MATRIX || right.getDataType() != DataType.MATRIX) {
        return hi;
    }

    boolean notBinaryMV = HopRewriteUtils.isNotMatrixVectorBinaryOperation(bop);
    boolean leftEmpty = HopRewriteUtils.isEmpty(left);
    boolean rightEmpty = HopRewriteUtils.isEmpty(right);

    Hop hnew = null;
    switch (bop.getOp()) {
        case MULT:
            // X * Y -> matrix(0, nrow, ncol)
            if (leftEmpty) {
                hnew = HopRewriteUtils.createDataGenOp(left, left, 0);
            } else if (rightEmpty) {
                // a right operand that is not a vector defines the output size itself;
                // a potentially broadcast vector does not, so fall back to left
                Hop sizeSource = (right.getDim1() > 1 && right.getDim2() > 1) ? right : left;
                hnew = HopRewriteUtils.createDataGenOp(sizeSource, sizeSource, 0);
            }
            break;
        case PLUS:
            if (leftEmpty && rightEmpty) {
                hnew = HopRewriteUtils.createDataGenOp(left, left, 0);
            } else if (leftEmpty && notBinaryMV) {
                hnew = right;
            } else if (rightEmpty) {
                hnew = left;
            }
            break;
        case MINUS:
            if (leftEmpty && notBinaryMV) {
                // rewrite in place: X - Y -> 0 - Y
                HopRewriteUtils.removeChildReference(hi, left);
                HopRewriteUtils.addChildReference(hi, new LiteralOp(0), 0);
                hnew = hi;
            } else if (rightEmpty) {
                hnew = left;
            }
            break;
        default:
            break;
    }

    if (hnew == null) {
        return hi;
    }
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    LOG.debug("Applied simplifyEmptyBinaryOperation");
    return hnew;
}
Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method reorderMinusMatrixMult.
/**
 * This rewrite tries to reorder minus operators from inputs of matrix
 * multiply to its output because the output is (except for outer products)
 * usually significantly smaller. Furthermore, this rewrite is a precondition
 * for the important hops-lops rewrite of transpose-matrixmult if the transpose
 * is hidden under the minus.
 *
 * Concretely, (0-Z) %*% Y becomes 0-(Z %*% Y), and X %*% (0-Z) becomes
 * 0-(X %*% Z). A rewrite is applied only if the dimensions of the matrix
 * multiply and of Z are known and HopRewriteUtils.compareSize(matrixmult, Z)
 * is negative. The left input is tried first; at most one input is rewritten
 * per call.
 *
 * NOTE: in this rewrite we need to modify the links to all parents because we
 * remove existing links of subdags and hence affect all consumers.
 *
 * @param parent the parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return the new minus operator if the rewrite applied, otherwise {@code hi}
 */
private static Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isMatrixMultiply(hi)) { // X%*%Y
        Hop hileft = hi.getInput().get(0);
        Hop hiright = hi.getInput().get(1);
        if (isNegationOfLargerInput(hi, hileft))
            hi = pullMinusAboveMatrixMult(hi, hileft, 0);
        else if (isNegationOfLargerInput(hi, hiright))
            hi = pullMinusAboveMatrixMult(hi, hiright, 1);
    }
    return hi;
}

/**
 * Checks whether {@code input} has the form (0 - Z), with known dimensions of
 * {@code mmult} and Z, and {@code mmult} comparing smaller than Z.
 */
private static boolean isNegationOfLargerInput(Hop mmult, Hop input) {
    return HopRewriteUtils.isBinary(input, OpOp2.MINUS)
        && input.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) input.getInput().get(0)) == 0.0
        && mmult.dimsKnown()
        && input.getInput().get(1).dimsKnown() // size comparison
        && HopRewriteUtils.compareSize(mmult, input.getInput().get(1)) < 0;
}

/**
 * Rewires mmult so that Z (the child of the (0 - Z) input at {@code inputPos})
 * feeds the matrix multiply directly. All former consumers of mmult are moved to
 * a new 0 - mmult operator, which is returned.
 */
@SuppressWarnings("unchecked")
private static Hop pullMinusAboveMatrixMult(Hop mmult, Hop minusInput, int inputPos) {
    Hop negated = minusInput.getInput().get(1);
    // remove link from matrixmult to minus
    HopRewriteUtils.removeChildReference(mmult, minusInput);
    // get old parents (before creating minus over matrix mult)
    ArrayList<Hop> parents = (ArrayList<Hop>) mmult.getParent().clone();
    // create new operators
    BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), mmult, OpOp2.MINUS);
    // rehang minus under all parents, keeping each parent's input position
    for (Hop p : parents) {
        int ix = HopRewriteUtils.getChildReferencePos(p, mmult);
        HopRewriteUtils.removeChildReference(p, mmult);
        HopRewriteUtils.addChildReference(p, minus, ix);
    }
    // rehang child of minus under matrix mult at the slot the minus occupied
    HopRewriteUtils.addChildReference(mmult, negated, inputPos);
    // cleanup if only consumer of minus
    HopRewriteUtils.cleanupUnreferenced(minusInput);
    LOG.debug("Applied reorderMinusMatrixMult (line " + minus.getBeginLine() + ").");
    return minus;
}
Use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by Apache.
In class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedUnaryMM.
/**
 * Fuses a weighted unary matrix-multiply expression into a single WUMM
 * quaternary operator. Patterns are tried in this order; at most one applies:
 *   1)   W op uop(U %*% t(V))      with op in LOOKUP_VALID_WDIVMM_BINARY and
 *                                  uop in LOOKUP_VALID_WUMM_UNARY
 *   2.7) (W * (U %*% t(V))) * 2    or 2 * (W * (U %*% t(V)))
 *   2)   W op sop(U %*% t(V), 2)   or W op (2 * (U %*% t(V))), with sop in
 *                                  LOOKUP_VALID_WUMM_BINARY
 * The weight input W must be a matrix with more than getColsInBlock() columns,
 * and U must pass HopRewriteUtils.isSingleBlock(U, true) (block size constraint).
 * If the right factor of the matrix product is already a transpose t(V), V is
 * taken directly; otherwise a transpose of that factor is created.
 *
 * @param parent parent of {@code hi}; receives the fused hop at {@code pos}
 * @param hi     candidate root of the pattern
 * @param pos    position of {@code hi} within the inputs of {@code parent}
 * @return the new WUMM hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
Hop hnew = null;
boolean appliedPattern = false;
// Pattern 1) (W*uop(U%*%t(V)))
if (hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // outer op in LOOKUP_VALID_WDIVMM_BINARY
HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && // prevent mv (equally sized operands)
hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && hi.getInput().get(1) instanceof UnaryOp && HopRewriteUtils.isValidOp(((UnaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY) && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp && // not applied for vector-vector mult; W a matrix wider than one column block; right input is uop(AggBinaryOp)
HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT on U
Hop W = hi.getInput().get(0);
Hop U = hi.getInput().get(1).getInput().get(0).getInput().get(0);
Hop V = hi.getInput().get(1).getInput().get(0).getInput().get(1);
// true if the outer op is a multiply; passed to the quaternary op as its mult flag
boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
OpOp1 op = ((UnaryOp) hi.getInput().get(1)).getOp();
// the pattern is U %*% t(V): unwrap an existing transpose, otherwise add one
if (!HopRewriteUtils.isTransposeOperation(V))
V = HopRewriteUtils.createTranspose(V);
else
V = V.getInput().get(0);
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, op, null);
hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
}
// Pattern 2.7) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
if (!appliedPattern && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), OpOp2.MULT) && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2) || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) {
// nl is the non-literal operand of the multiply by 2 (if input 0 is a literal, take input 1)
final Hop nl;
if (hi.getInput().get(0) instanceof LiteralOp) {
nl = hi.getInput().get(1);
} else {
nl = hi.getInput().get(0);
}
if (HopRewriteUtils.isBinary(nl, OpOp2.MULT) && // nl must be W * (...)
nl.getParent().size() == 1 && // ensure no foreign parents
HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) && // prevent mv (equally sized operands)
nl.getDim2() > 1 && nl.getInput().get(0).getDataType() == DataType.MATRIX && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock() && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1)) && // not applied for vector-vector mult; W a matrix wider than one column block; outer-product-like MM
(((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0), true)) { // no mmchain (unless the right MM input has >1 column); BLOCKSIZE CONSTRAINT on U
final Hop W = nl.getInput().get(0);
final Hop U = nl.getInput().get(1).getInput().get(0);
Hop V = nl.getInput().get(1).getInput().get(1);
// the pattern is U %*% t(V): unwrap an existing transpose, otherwise add one
if (!HopRewriteUtils.isTransposeOperation(V))
V = HopRewriteUtils.createTranspose(V);
else
V = V.getInput().get(0);
// fused with mult=true and binary sop MULT; the literal 2 itself is not passed
// (presumably implied by the MULT sop, as in pattern 2 -- TODO confirm in QuaternaryOp)
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
}
}
// Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
if (!appliedPattern && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // outer op in LOOKUP_VALID_WDIVMM_BINARY
HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && // prevent mv (equally sized operands)
hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && hi.getInput().get(1) instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) { // not applied for vector-vector mult; W a matrix wider than one column block; right input is sop in LOOKUP_VALID_WUMM_BINARY
Hop left = hi.getInput().get(1).getInput().get(0);
Hop right = hi.getInput().get(1).getInput().get(1);
// abop: the U %*% t(V) operand of sop, if one of the sub-patterns matches
Hop abop = null;
// pattern 2a) matrix-scalar operations
if (right.getDataType() == DataType.SCALAR && right instanceof LiteralOp && // pow2, mult2: scalar literal on the right...
HopRewriteUtils.getDoubleValue((LiteralOp) right) == 2 && left instanceof AggBinaryOp && // ...equal to 2, with a matrix product on the left
HopRewriteUtils.isSingleBlock(left.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT on U
abop = left;
} else // pattern 2b) scalar-matrix operations
if (left.getDataType() == DataType.SCALAR && left instanceof LiteralOp && // mult2: scalar literal on the left...
HopRewriteUtils.getDoubleValue((LiteralOp) left) == 2 && ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MULT && right instanceof AggBinaryOp && // ...equal to 2, only for MULT, with a matrix product on the right
HopRewriteUtils.isSingleBlock(right.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT on U
abop = right;
}
if (abop != null) {
Hop W = hi.getInput().get(0);
Hop U = abop.getInput().get(0);
Hop V = abop.getInput().get(1);
boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
OpOp2 op = ((BinaryOp) hi.getInput().get(1)).getOp();
// the pattern is U %*% t(V): unwrap an existing transpose, otherwise add one
if (!HopRewriteUtils.isTransposeOperation(V))
V = HopRewriteUtils.createTranspose(V);
else
V = V.getInput().get(0);
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, null, op);
hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
hnew.refreshSizeInformation();
appliedPattern = true;
LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
}
}
// relink new hop into original position
if (hnew != null) {
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
return hi;
}
Aggregations