Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifySumMatrixMult.
/**
 * Pushes a sum aggregate below the matrix multiplication it consumes:
 * <ul>
 *   <li>{@code sum(A%*%B)} becomes {@code sum(t(colSums(A)) * rowSums(B))} (later rewritten to a dot product),</li>
 *   <li>{@code colSums(A%*%B)} becomes {@code colSums(A) %*% B},</li>
 *   <li>{@code rowSums(A%*%B)} becomes {@code A %*% rowSums(B)}.</li>
 * </ul>
 * The rewrite is skipped unless at least one dimension of the product is known to exceed 1
 * (i.e. it is not a dot product), and it is skipped when the matrix multiplication has consumers
 * other than {@code hi}, to avoid computing redundant work.
 *
 * @param parent parent of {@code hi}; not used by this rewrite (kept for the common rewrite signature)
 * @param hi     candidate sum aggregate; it stays in place and only its input is replaced
 * @param pos    position of {@code hi} within {@code parent}'s inputs; not used by this rewrite
 * @return {@code hi} in all cases
 */
private static Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) {
    boolean applicable = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM                   // sum aggregate ...
        && hi.getInput().get(0) instanceof AggBinaryOp                // ... over A%*%B
        && (hi.getInput().get(0).getDim1() > 1
            || hi.getInput().get(0).getDim2() > 1)                    // not a dot product
        && hi.getInput().get(0).getParent().size() == 1;              // sum is the only consumer
    if (!applicable) {
        return hi;
    }

    final Hop matMult = hi.getInput().get(0);
    final Hop lhs = matMult.getInput().get(0);
    final Hop rhs = matMult.getInput().get(1);

    // detach the matrix multiply from the aggregate before building its replacement
    HopRewriteUtils.removeChildReference(hi, matMult);

    final Direction dir = ((AggUnaryOp) hi).getDirection();
    Hop replacement = null;
    if (dir == Direction.RowCol) {
        // full sum: t(colSums(A)) * rowSums(B)
        AggUnaryOp lhsColSums = HopRewriteUtils.createAggUnaryOp(lhs, AggOp.SUM, Direction.Col);
        ReorgOp lhsColSumsT = HopRewriteUtils.createTranspose(lhsColSums);
        AggUnaryOp rhsRowSums = HopRewriteUtils.createAggUnaryOp(rhs, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createBinary(lhsColSumsT, rhsRowSums, OpOp2.MULT);
        LOG.debug("Applied simplifySumMatrixMult RC.");
    } else if (dir == Direction.Col) {
        // column sums: colSums(A) %*% B
        AggUnaryOp lhsColSums = HopRewriteUtils.createAggUnaryOp(lhs, AggOp.SUM, Direction.Col);
        replacement = HopRewriteUtils.createMatrixMultiply(lhsColSums, rhs);
        LOG.debug("Applied simplifySumMatrixMult C.");
    } else if (dir == Direction.Row) {
        // row sums: A %*% rowSums(B)
        AggUnaryOp rhsRowSums = HopRewriteUtils.createAggUnaryOp(rhs, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createMatrixMultiply(lhs, rhsRowSums);
        LOG.debug("Applied simplifySumMatrixMult R.");
    }

    // hang the new sub-DAG under the existing aggregate (hi itself is kept)
    HopRewriteUtils.addChildReference(hi, replacement, 0);
    hi.refreshSizeInformation();
    // release the old matrix multiply if nothing references it any more
    HopRewriteUtils.cleanupUnreferenced(matMult);
    return hi;
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method fuseAxpyBinaryOperationChain.
/**
 * Fuses an addition or subtraction with a single-consumer scalar-matrix multiplication into
 * one ternary axpy-style operator:
 * <ul>
 *   <li>(a) {@code X + s*Y  ->  X +* sY}</li>
 *   <li>(b) {@code s*Y + X  ->  X +* sY}</li>
 *   <li>(c) {@code X - s*Y  ->  X -* sY}</li>
 * </ul>
 * Outer-vector binary operations are left alone, and both operands must have equal size.
 * When {@code s} is a literal zero, the expression is replaced by {@code X} itself.
 *
 * @param parent parent of {@code hi}; rewired to the fused operator if the rewrite applies
 * @param hi     candidate binary operator
 * @param pos    position of {@code hi} within {@code parent}'s inputs
 * @return the fused replacement if a pattern matched, otherwise {@code hi}
 */
private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    final BinaryOp bop = (BinaryOp) hi;
    if (bop.isOuterVectorOperator()
            || (bop.getOp() != OpOp2.PLUS && bop.getOp() != OpOp2.MINUS)) {
        return hi;
    }

    final Hop lhs = bop.getInput().get(0);
    final Hop rhs = bop.getInput().get(1);
    // past the guard above the operator is either PLUS or MINUS
    final boolean isPlus = bop.getOp() == OpOp2.PLUS;
    Hop fused = null;

    if (isPlus && lhs.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isScalarMatrixBinaryMult(rhs)
            && HopRewriteUtils.isEqualSize(lhs, rhs)
            && rhs.getParent().size() == 1) {
        // (a) X + s*Y, where s*Y has a single consumer
        fused = buildAxpyReplacement(lhs, rhs, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " + hi.getBeginLine() + ")");
    } else if (isPlus && rhs.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isScalarMatrixBinaryMult(lhs)
            && HopRewriteUtils.isEqualSize(lhs, rhs)
            && lhs.getParent().size() == 1) {
        // (b) s*Y + X, where s*Y has a single consumer
        fused = buildAxpyReplacement(rhs, lhs, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " + hi.getBeginLine() + ")");
    } else if (!isPlus && lhs.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isScalarMatrixBinaryMult(rhs)
            && HopRewriteUtils.isEqualSize(lhs, rhs)
            && rhs.getParent().size() == 1) {
        // (c) X - s*Y, where s*Y has a single consumer
        fused = buildAxpyReplacement(lhs, rhs, OpOp3.MINUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " + hi.getBeginLine() + ")");
    }

    if (fused == null) {
        return hi;
    }
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}

/**
 * Builds the fused operator {@code matrix op (s, Y)} for a scalar-matrix product {@code s*Y}.
 * The scalar may sit at either input position of the product. Returns {@code matrix}
 * unchanged when the scalar is a literal equal to zero.
 *
 * @param matrix     the matrix operand X kept outside the product
 * @param scalarMult the scalar-matrix product s*Y
 * @param op         ternary operator to create (e.g. PLUS_MULT or MINUS_MULT)
 * @return {@code matrix} for a zero literal scalar, otherwise a new ternary operator
 */
private static Hop buildAxpyReplacement(Hop matrix, Hop scalarMult, OpOp3 op) {
    final int scalarIdx = (scalarMult.getInput().get(0).getDataType() == DataType.SCALAR) ? 0 : 1;
    final Hop scalar = scalarMult.getInput().get(scalarIdx);
    final Hop otherOperand = scalarMult.getInput().get(1 - scalarIdx);
    final boolean zeroLiteral = scalar instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) scalar) == 0;
    return zeroLiteral ? matrix : HopRewriteUtils.createTernaryOp(matrix, scalar, otherOperand, op);
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method removeUnnecessaryRightIndexing.
/**
 * Removes a right-indexing operation that {@code HopRewriteUtils.isUnnecessaryRightIndexing}
 * reports as unnecessary, so that {@code parent} consumes the indexed input directly.
 *
 * @param parent parent of {@code hi}; rewired to the indexed input if the rewrite applies
 * @param hi     candidate right-indexing operator
 * @param pos    position of {@code hi} within {@code parent}'s inputs
 * @return the indexed input if the rewrite applied, otherwise {@code hi}
 */
private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos) {
    if (!HopRewriteUtils.isUnnecessaryRightIndexing(hi)) {
        return hi;
    }
    final Hop source = hi.getInput().get(0);
    HopRewriteUtils.replaceChildReference(parent, hi, source, pos);
    // release the indexing op if it has no remaining consumers
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied removeUnnecessaryRightIndexing");
    return source;
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyEmptyMatrixMult.
/**
 * Replaces a matrix multiplication {@code X %*% Y} with a zero-valued data-generation op when
 * either input is empty, as determined by {@code HopRewriteUtils.isEmpty}. The new op is built
 * by {@code HopRewriteUtils.createDataGenOp(X, Y, 0)}; presumably it is sized
 * nrow(X) x ncol(Y) — confirm against that helper.
 *
 * @param parent parent of {@code hi}; rewired to the new op if the rewrite applies
 * @param hi     candidate matrix multiplication
 * @param pos    position of {@code hi} within {@code parent}'s inputs
 * @return the new data-generation op if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos) {
    // X%*%Y -> matrix(0, ...)
    if (HopRewriteUtils.isMatrixMultiply(hi)) {
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        // one input empty
        if (HopRewriteUtils.isEmpty(left) || HopRewriteUtils.isEmpty(right)) {
            // create datagen and add it to parent
            Hop hnew = HopRewriteUtils.createDataGenOp(left, right, 0);
            HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
            // Release the replaced matrix multiply if nothing else consumes it. Otherwise it
            // stays registered as a parent of left/right, which inflates their consumer counts
            // and blocks rewrites that require a single consumer.
            HopRewriteUtils.cleanupUnreferenced(hi);
            hi = hnew;
            LOG.debug("Applied simplifyEmptyMatrixMult");
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method removeUnnecessaryLeftIndexing.
/**
 * Removes a left-indexing operation whose right-hand side already has the same dimensions
 * as the indexing result. In {@code X[...] = Y} with dims(Y) == dims(result), the result is
 * {@code Y} itself, so {@code parent} is rewired to consume {@code Y} directly.
 *
 * @param parent parent of {@code hi}; rewired to the right-hand side if the rewrite applies
 * @param hi     candidate left-indexing operator
 * @param pos    position of {@code hi} within {@code parent}'s inputs
 * @return the right-hand side if the rewrite applied, otherwise {@code hi}
 */
private static Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof LeftIndexingOp)) {
        return hi;
    }
    // right-hand side (matrix/frame) written into the target
    final Hop rhs = hi.getInput().get(1);
    if (!HopRewriteUtils.isEqualSize(hi, rhs)) {
        return hi;
    }
    // rhs covers the whole output, so the left indexing is redundant: drop it
    HopRewriteUtils.replaceChildReference(parent, hi, rhs, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied removeUnnecessaryLeftIndexing");
    return rhs;
}
Aggregations