use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method fuseAxpyBinaryOperationChain.
/**
 * Fuses an axpy-style binary operation chain into a single ternary operator.
 * <p>
 * Patterns (s*Y must be a scalar-matrix multiply of the same size as X, with a single consumer):
 * (a) X + s*Y -> X +* sY, (b) s*Y + X -> X +* sY, (c) X - s*Y -> X -* sY,
 * where +* / -* are the PLUS_MULT / MINUS_MULT ternary operators.
 * If s is the literal 0, the chain collapses to X itself.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator to inspect
 * @param pos position of hi within the inputs of parent
 * @return the replacement operator if a pattern applied, otherwise hi
 */
private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp))
        return hi;
    BinaryOp bop = (BinaryOp) hi;
    if (bop.isOuterVectorOperator())
        return hi;
    boolean isPlus = bop.getOp() == OpOp2.PLUS;
    boolean isMinus = bop.getOp() == OpOp2.MINUS;
    if (!isPlus && !isMinus)
        return hi;

    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    // select the plain matrix operand X, the product s*Y, and the fused operator
    Hop matrixOperand = null;
    Hop scaledOperand = null;
    OpOp3 fusedOp = null;
    String logPrefix = null;
    if (isPlus && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(right)
        && HopRewriteUtils.isEqualSize(left, right)
        && right.getParent().size() == 1) {
        // pattern (a) X + s*Y
        matrixOperand = left;
        scaledOperand = right;
        fusedOp = OpOp3.PLUS_MULT;
        logPrefix = "Applied fuseAxpyBinaryOperationChain1. (line ";
    } else if (isPlus && right.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(left)
        && HopRewriteUtils.isEqualSize(left, right)
        && left.getParent().size() == 1) {
        // pattern (b) s*Y + X
        matrixOperand = right;
        scaledOperand = left;
        fusedOp = OpOp3.PLUS_MULT;
        logPrefix = "Applied fuseAxpyBinaryOperationChain2. (line ";
    } else if (isMinus && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(right)
        && HopRewriteUtils.isEqualSize(left, right)
        && right.getParent().size() == 1) {
        // pattern (c) X - s*Y
        matrixOperand = left;
        scaledOperand = right;
        fusedOp = OpOp3.MINUS_MULT;
        logPrefix = "Applied fuseAxpyBinaryOperationChain3. (line ";
    }
    if (scaledOperand == null)
        return hi;

    // the scalar of s*Y may sit on either side of the multiply
    int scalarIdx = (scaledOperand.getInput().get(0).getDataType() == DataType.SCALAR) ? 0 : 1;
    Hop scalar = scaledOperand.getInput().get(scalarIdx);
    Hop scaledMatrix = scaledOperand.getInput().get(1 - scalarIdx);

    boolean zeroScale = scalar instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) scalar) == 0;
    Hop ternop = zeroScale
        ? matrixOperand
        : HopRewriteUtils.createTernaryOp(matrixOperand, scalar, scaledMatrix, fusedOp);
    LOG.debug(logPrefix + hi.getBeginLine() + ")");

    // rewire parent-child operators
    if (ternop != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, ternop, pos);
        hi = ternop;
    }
    return hi;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyEmptyReorgOperation.
/**
 * Replaces a reorg operation over an empty input with a data generator of value 0.
 * <p>
 * Handled ops: TRANSPOSE, REV, DIAG (only when the input dimensions are known; diagv2m
 * for single-column inputs, diagm2v otherwise) and RESHAPE (built from the reshape's
 * inputs 1 and 2, presumably the target rows/cols). Other reorg ops are left untouched.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator to inspect
 * @param pos position of hi within the inputs of parent
 * @return the replacement operator if a rule applied, otherwise hi
 */
private static Hop simplifyEmptyReorgOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof ReorgOp))
        return hi;
    ReorgOp reorg = (ReorgOp) hi;
    Hop in = reorg.getInput().get(0);
    if (!HopRewriteUtils.isEmpty(in)) // only empty inputs are rewritten
        return hi;

    ReOrgOp kind = reorg.getOp();
    Hop replacement = null;
    if (kind == ReOrgOp.TRANSPOSE) {
        replacement = HopRewriteUtils.createDataGenOp(in, true, in, true, 0);
    } else if (kind == ReOrgOp.REV) {
        replacement = HopRewriteUtils.createDataGenOp(in, 0);
    } else if (kind == ReOrgOp.DIAG) {
        if (HopRewriteUtils.isDimsKnown(in)) {
            if (in.getDim2() == 1) {
                // diagv2m
                replacement = HopRewriteUtils.createDataGenOp(in, false, in, true, 0);
            } else {
                // diagm2v
                replacement = HopRewriteUtils.createDataGenOpByVal(
                    HopRewriteUtils.createValueHop(in, true), new LiteralOp(1), 0);
            }
        }
    } else if (kind == ReOrgOp.RESHAPE) {
        replacement = HopRewriteUtils.createDataGenOpByVal(
            reorg.getInput().get(1), reorg.getInput().get(2), 0);
    }

    if (replacement == null)
        return hi;
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    LOG.debug("Applied simplifyEmptyReorgOperation");
    return replacement;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedSigmoidMMChains.
/**
 * Fuses weighted sigmoid matrix-multiplication chains into a single WSIGMOID quaternary operator.
 * <p>
 * Patterns (all rooted at an element-wise multiply by a matrix W):
 * 1) W * sigmoid(Y%*%t(X))          (basic)
 * 2) W * sigmoid(-(Y%*%t(X)))       (minus)
 * 3) W * log(sigmoid(Y%*%t(X)))     (log)
 * 4) W * log(sigmoid(-(Y%*%t(X))))  (log_minus)
 * The negation is matched as {@code 0 - (Y%*%t(X))} with a literal zero, and Y must
 * satisfy {@code HopRewriteUtils.isSingleBlock(Y, true)}.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator to inspect
 * @param pos position of hi within the inputs of parent
 * @return the fused operator if a pattern applied, otherwise hi
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    Hop hnew = null;
    if (HopRewriteUtils.isBinary(hi, OpOp2.MULT) // all patterns subrooted by W *
        && hi.getDim2() > 1 // not applied for vector-vector mult
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(1) instanceof UnaryOp) // sigmoid/log
    {
        UnaryOp uop = (UnaryOp) hi.getInput().get(1);
        Hop uin = uop.getInput().get(0);

        // The four patterns are mutually exclusive (SIGMOID vs LOG root, and a matrix
        // multiply vs a MINUS below it), so an else-if chain tries them in order.
        if (uop.getOp() == OpOp1.SIGMOID
            && uin instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(uin.getInput().get(0), true)) {
            // Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
            hnew = buildWeightedSigmoidOp(hi, uin, false, false);
            LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
        } else if (uop.getOp() == OpOp1.SIGMOID
            && HopRewriteUtils.isBinary(uin, OpOp2.MINUS)
            && uin.getInput().get(0) instanceof LiteralOp
            && HopRewriteUtils.getDoubleValueSafe((LiteralOp) uin.getInput().get(0)) == 0
            && uin.getInput().get(1) instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(uin.getInput().get(1).getInput().get(0), true)) {
            // Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
            hnew = buildWeightedSigmoidOp(hi, uin.getInput().get(1), false, true);
            LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
        } else if (uop.getOp() == OpOp1.LOG
            && HopRewriteUtils.isUnary(uin, OpOp1.SIGMOID)
            && uin.getInput().get(0) instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(uin.getInput().get(0).getInput().get(0), true)) {
            // Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
            hnew = buildWeightedSigmoidOp(hi, uin.getInput().get(0), true, false);
            LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
        } else if (uop.getOp() == OpOp1.LOG
            && HopRewriteUtils.isUnary(uin, OpOp1.SIGMOID)
            && HopRewriteUtils.isBinary(uin.getInput().get(0), OpOp2.MINUS)) {
            // Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
            Hop neg = uin.getInput().get(0);
            if (neg.getInput().get(0) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) neg.getInput().get(0)) == 0
                && neg.getInput().get(1) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(neg.getInput().get(1).getInput().get(0), true)) {
                hnew = buildWeightedSigmoidOp(hi, neg.getInput().get(1), true, true);
                LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
            }
        }
    }

    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}

/**
 * Builds the WSIGMOID quaternary operator for {@code W * f(Y %*% tX)}.
 * <p>
 * The operator is built over X rather than t(X). An existing transpose on the right side
 * of the multiply is stripped; otherwise a new transpose is inserted.
 *
 * @param hi the root {@code W * ...} multiply; supplies W (input 0) and the output name
 * @param mm the matrix multiply {@code Y %*% tX}
 * @param logOut true for the log patterns (3, 4)
 * @param minusIn true for the negated patterns (2, 4)
 * @return the new, size-refreshed WSIGMOID operator
 */
private static Hop buildWeightedSigmoidOp(Hop hi, Hop mm, boolean logOut, boolean minusIn) {
    Hop W = hi.getInput().get(0);
    Hop Y = mm.getInput().get(0);
    Hop tX = mm.getInput().get(1);
    if (!HopRewriteUtils.isTransposeOperation(tX)) {
        tX = HopRewriteUtils.createTranspose(tX);
    } else {
        tX = tX.getInput().get(0);
    }
    Hop hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
        OpOp4.WSIGMOID, W, Y, tX, logOut, minusIn);
    hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    hnew.refreshSizeInformation();
    return hnew;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyEmptySortOperation.
/**
 * Replaces a sort (order) over an empty input with a data generator.
 * <p>
 * order(X, indexreturn=TRUE)  -> seq(1,nrow(X),1)
 * order(X, indexreturn=FALSE) -> createDataGenOp(X, 0), presumably zeros shaped like X
 * Applies only when the indexreturn argument (input 3) is a literal.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator to inspect
 * @param pos position of hi within the inputs of parent
 * @return the replacement operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof ReorgOp) || ((ReorgOp) hi).getOp() != ReOrgOp.SORT)
        return hi;
    ReorgOp sortOp = (ReorgOp) hi;
    Hop data = sortOp.getInput().get(0);
    if (!HopRewriteUtils.isEmpty(data)) // only empty inputs are rewritten
        return hi;

    Hop ixretInput = sortOp.getInput().get(3);
    if (!(ixretInput instanceof LiteralOp)) // index return not known at compile time
        return hi;
    boolean ixret = HopRewriteUtils.getBooleanValue((LiteralOp) ixretInput);

    Hop hnew;
    if (ixret) {
        hnew = HopRewriteUtils.createSeqDataGenOp(data);
    } else {
        hnew = HopRewriteUtils.createDataGenOp(data, 0);
    }
    if (hnew == null)
        return hi;

    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    LOG.debug("Applied simplifyEmptySortOperation (indexreturn=" + ixret + ").");
    return hnew;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyDotProductSum.
/**
 * Rewrites sum(v^2) and sum(v1*v2) over column vectors into as.scalar(t(v1) %*% v2).
 * The dot product is then computed by a matrix multiply, without materializing the
 * element-wise intermediate.
 * <p>
 * NOTE: a general mm a%*%b on MR can be counter-productive (e.g., MMCJ), while tsmm (which
 * t(v)%*%v for sum(v^2) can exploit) is always beneficial. The sum(v1*v2) case is still
 * rewritten here for column vectors, with two exceptions:
 * <ul>
 * <li>products whose factor is itself a multiply (sum(v1*v2*v3)), which are later compiled
 * into a ta+* lop;</li>
 * <li>when ALLOW_SUM_PRODUCT_REWRITES is set, products with a squared factor
 * ((A^2)*B or B*(A^2)), which are left for tak+*.</li>
 * </ul>
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // sum(v^2)/sum(v1*v2) --> as.scalar(t(v1)%*%v2), w/o materialization of intermediates
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.SUM // sum
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol // full aggregate
        && hi.getInput().get(0).getDim2() == 1) // vector (for correctness)
    {
        Hop baLeft = null;
        Hop baRight = null;
        Hop hi2 = hi.getInput().get(0);

        // check for sum(v^2) w/o other consumers; might have been rewritten from sum(v*v)
        if (isPowOfLiteralTwo(hi2) && hi2.getParent().size() == 1) {
            Hop input = hi2.getInput().get(0);
            baLeft = input;
            baRight = input;
        }
        // check for sum(v1*v2) w/o other consumers, but prevent rewriting sum(v1*v2*v3),
        // which is later compiled into a ta+* lop
        else if (HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1)
            && hi2.getInput().get(0).getDim2() == 1
            && hi2.getInput().get(1).getDim2() == 1
            && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
            && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
            // do not rewrite (A^2)*B or B*(A^2); let tak+* handle it
            && (!ALLOW_SUM_PRODUCT_REWRITES
                || (!isPowOfLiteralTwo(hi2.getInput().get(0))
                    && !isPowOfLiteralTwo(hi2.getInput().get(1)))))
        {
            baLeft = hi2.getInput().get(0);
            baRight = hi2.getInput().get(1);
        }

        // perform actual rewrite (if necessary)
        if (baLeft != null && baRight != null) {
            // create new operator chain
            ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
            AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
            UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);
            // rehang new subdag under parent node
            HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, hi2);
            hi = cast;
            LOG.debug("Applied simplifyDotProductSum.");
        }
    }
    return hi;
}

/**
 * Checks whether h is a binary POW whose exponent is a literal equal to 2.
 * <p>
 * The exponent is compared as a double rather than truncated to a long, so an exponent
 * such as 2.5 does not match. This is the same comparison the sum(v^2) check has always used.
 *
 * @param h high-level operator to test
 * @return true iff h is {@code X ^ 2} with a literal exponent
 */
private static boolean isPowOfLiteralTwo(Hop h) {
    if (!HopRewriteUtils.isBinary(h, OpOp2.POW))
        return false;
    Hop exponent = h.getInput().get(1);
    return exponent instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) exponent) == 2;
}
Aggregations