Use of org.apache.sysml.hops.Hop in the project incubator-systemml by Apache.
From the class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSigmoidMMChains.
/**
 * Fuses an element-wise multiplication of a weight matrix with a sigmoid chain over a
 * matrix product into a single WSIGMOID quaternary operator.
 *
 * Recognized shapes of {@code hi} (the first one that matches is applied):
 * <ol>
 *   <li>{@code W * sigmoid(Y %*% t(X))} (basic)</li>
 *   <li>{@code W * sigmoid(-(Y %*% t(X)))} (minus; matched as {@code 0 - (Y %*% t(X))})</li>
 *   <li>{@code W * log(sigmoid(Y %*% t(X)))} (log)</li>
 *   <li>{@code W * log(sigmoid(-(Y %*% t(X))))} (log_minus)</li>
 * </ol>
 * If the second matrix-multiply input is a transpose, its input is used directly;
 * otherwise a new transpose is created on top of it.
 *
 * @param parent hop whose child reference is redirected when the rewrite applies
 * @param hi     candidate hop to rewrite
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the new WSIGMOID hop if a pattern matched, otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    // all patterns are subrooted by W * ...; not applied for vector-vector mult
    if (!HopRewriteUtils.isBinary(hi, OpOp2.MULT) || hi.getDim2() <= 1) {
        return hi;
    }

    final Hop W = hi.getInput().get(0);
    final Hop rhs = hi.getInput().get(1);

    // both operands must have equal size (prevents mv), W must be a matrix,
    // and the right operand must be the unary sigmoid/log
    if (!HopRewriteUtils.isEqualSize(W, rhs)
            || W.getDataType() != DataType.MATRIX
            || !(rhs instanceof UnaryOp)) {
        return hi;
    }

    final UnaryOp uop = (UnaryOp) rhs;
    Hop mm = null;            // matched matrix product Y %*% t(X), if any
    boolean logOut = false;   // first boolean flag of the WSIGMOID op (log patterns)
    boolean minusIn = false;  // second boolean flag of the WSIGMOID op (minus patterns)
    String logPrefix = null;  // pattern-specific head of the debug message

    if (uop.getOp() == OpOp1.SIGMOID) {
        Hop arg = uop.getInput().get(0);
        if (arg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
            // Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
            mm = arg;
            logPrefix = "Applied simplifyWeightedSigmoid1 (line ";
        } else if (HopRewriteUtils.isBinary(arg, OpOp2.MINUS)) {
            // Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
            mm = matchNegatedSingleBlockMM(arg);
            minusIn = true;
            logPrefix = "Applied simplifyWeightedSigmoid2 (line ";
        }
    } else if (uop.getOp() == OpOp1.LOG) {
        Hop sig = uop.getInput().get(0);
        if (HopRewriteUtils.isUnary(sig, OpOp1.SIGMOID)) {
            Hop arg = sig.getInput().get(0);
            if (arg instanceof AggBinaryOp
                    && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
                // Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
                mm = arg;
                logOut = true;
                logPrefix = "Applied simplifyWeightedSigmoid3 (line ";
            } else if (HopRewriteUtils.isBinary(arg, OpOp2.MINUS)) {
                // Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
                mm = matchNegatedSingleBlockMM(arg);
                logOut = true;
                minusIn = true;
                logPrefix = "Applied simplifyWeightedSigmoid4 (line ";
            }
        }
    }

    if (mm == null) {
        return hi; // no pattern matched
    }

    // operand extraction and operator construction shared by all four patterns
    Hop Y = mm.getInput().get(0);
    Hop tX = mm.getInput().get(1);
    Hop X = HopRewriteUtils.isTransposeOperation(tX)
            ? tX.getInput().get(0)
            : HopRewriteUtils.createTranspose(tX);

    Hop hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
            OpOp4.WSIGMOID, W, Y, X, logOut, minusIn);
    hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    hnew.refreshSizeInformation();
    LOG.debug(logPrefix + hi.getBeginLine() + ")");

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}

/**
 * Matches {@code 0 - (A %*% B)} where {@code A} passes
 * {@code HopRewriteUtils.isSingleBlock(A, true)}.
 *
 * @param minus hop already known to be a binary MINUS
 * @return the matrix-multiply hop on the right of the minus, or {@code null} if the shape does not match
 */
private static Hop matchNegatedSingleBlockMM(Hop minus) {
    Hop zero = minus.getInput().get(0);
    if (!(zero instanceof LiteralOp)
            || HopRewriteUtils.getDoubleValueSafe((LiteralOp) zero) != 0) {
        return null;
    }
    Hop mm = minus.getInput().get(1);
    if (mm instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(mm.getInput().get(0), true)) {
        return mm;
    }
    return null;
}
Use of org.apache.sysml.hops.Hop in the project incubator-systemml by Apache.
From the class RewriteAlgebraicSimplificationDynamic, method simplifyIdentityRepMatrixMult.
/**
 * Removes a matrix multiplication with a constant 1x1 matrix of ones:
 * {@code X %*% y -> X}, if {@code y} is {@code matrix(1,1,1)}.
 *
 * The right operand qualifies only if its dimensions are known to be 1x1 and it is a
 * RAND data generator with constant value 1.0.
 *
 * @param parent hop whose child reference is redirected when the rewrite applies
 * @param hi     candidate hop to rewrite
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the left operand of the multiplication if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyIdentityRepMatrixMult(Hop parent, Hop hi, int pos) {
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }

    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    // scalar right: known 1x1 dimensions
    boolean oneByOne = HopRewriteUtils.isDimsKnown(rhs)
            && rhs.getDim1() == 1
            && rhs.getDim2() == 1;
    if (!oneByOne || !(rhs instanceof DataGenOp)) {
        return hi;
    }

    // matrix(1,): rand generator with constant value 1
    DataGenOp gen = (DataGenOp) rhs;
    if (gen.getOp() != DataGenMethod.RAND || !gen.hasConstantValue(1.0)) {
        return hi;
    }

    // X %*% y -> X
    HopRewriteUtils.replaceChildReference(parent, hi, lhs, pos);
    LOG.debug("Applied simplifyIdentiyMatrixMult");
    return lhs;
}
Use of org.apache.sysml.hops.Hop in the project incubator-systemml by Apache.
From the class RewriteAlgebraicSimplificationDynamic, method simplifyMatrixMultDiag.
/**
 * Replaces a matrix multiplication whose left operand is a vector-to-matrix diag with an
 * element-wise multiplication over the diag input.
 *
 * <ul>
 *   <li>right operand has one column: {@code diag(X) %*% Y -> X * Y}</li>
 *   <li>right operand has more than one column: {@code diag(X) %*% Y -> Y * X}; the operands
 *       are swapped because MV binary cell operations require the vector on the right</li>
 * </ul>
 * Similar rewrites would be possible for a diag on the right side, but are not implemented here.
 *
 * @param parent hop whose child reference is redirected when the rewrite applies
 * @param hi     candidate hop to rewrite
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the new element-wise multiplication if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyMatrixMultDiag(Hop parent, Hop hi, int pos) {
    // X%*%Y
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }

    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    // left diag with known dims and more than one output column (diagV2M)
    boolean leftDiagV2M = lhs instanceof ReorgOp
            && ((ReorgOp) lhs).getOp() == ReOrgOp.DIAG
            && HopRewriteUtils.isDimsKnown(lhs)
            && lhs.getDim2() > 1;
    if (!leftDiagV2M) {
        return hi;
    }

    Hop product;
    if (rhs.getDim2() == 1) {
        // right column vector: binary operation over diag input and right
        Hop diagInput = lhs.getInput().get(0);
        product = HopRewriteUtils.createBinary(diagInput, rhs, OpOp2.MULT);
        LOG.debug("Applied simplifyMatrixMultDiag1");
    } else if (rhs.getDim2() > 1) {
        // multi-column right: switch the operand order so the vector is on the right
        // NOTE: replicating the left operand, as done before MV binary cell operations
        // existed, is no longer required
        Hop diagInput = lhs.getInput().get(0);
        product = HopRewriteUtils.createBinary(rhs, diagInput, OpOp2.MULT);
        LOG.debug("Applied simplifyMatrixMultDiag2");
    } else {
        // neither case applies to this column count of the right operand
        return hi;
    }

    // add mult to parent and clean up the replaced matrix multiplication
    HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    return product;
}
Use of org.apache.sysml.hops.Hop in the project incubator-systemml by Apache.
From the class RewriteAlgebraicSimplificationDynamic, method simplifyEmptySortOperation.
/**
 * Replaces a sort over an input that {@code HopRewriteUtils.isEmpty} reports as empty with a
 * data generator, provided the index-return argument (input 3 of the sort) is a literal.
 *
 * <ul>
 *   <li>index return true: {@code order(X, indexreturn=TRUE) -> seq(1,nrow(X),1)}, built via
 *       {@code HopRewriteUtils.createSeqDataGenOp(input)}</li>
 *   <li>index return false: replaced by {@code HopRewriteUtils.createDataGenOp(input, 0)}</li>
 * </ul>
 *
 * @param parent hop whose child reference is redirected when the rewrite applies
 * @param hi     candidate hop to rewrite
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the generated replacement hop if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof ReorgOp) || ((ReorgOp) hi).getOp() != ReOrgOp.SORT) {
        return hi;
    }

    Hop data = hi.getInput().get(0);
    if (!HopRewriteUtils.isEmpty(data)) {
        return hi; // only empty inputs are rewritten
    }

    // the rewrite needs the index-return flag to be known
    Hop ixretHop = hi.getInput().get(3);
    if (!(ixretHop instanceof LiteralOp)) {
        return hi;
    }

    boolean ixret = HopRewriteUtils.getBooleanValue((LiteralOp) ixretHop);
    Hop hnew = ixret
            ? HopRewriteUtils.createSeqDataGenOp(data)
            : HopRewriteUtils.createDataGenOp(data, 0);
    if (hnew == null) {
        return hi; // nothing was generated, leave the dag untouched
    }

    // modify dag
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    LOG.debug("Applied simplifyEmptySortOperation (indexreturn=" + ixret + ").");
    return hnew;
}
Use of org.apache.sysml.hops.Hop in the project incubator-systemml by Apache.
From the class RewriteAlgebraicSimplificationDynamic, method simplifyColSumsMVMult.
/**
 * Rewrites a column-wise sum over an element-wise matrix-vector multiplication into a
 * matrix multiplication: {@code colSums(X * Y) -> t(Y) %*% X}, where {@code X} has more than
 * one row and column and {@code Y} is a column vector with more than one row.
 *
 * The transpose introduced here is removed by another rewrite if unnecessary,
 * i.e., if {@code Y == t(Z)}.
 *
 * @param parent hop whose child reference is redirected when the rewrite applies
 * @param hi     candidate hop to rewrite
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the new matrix multiplication if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyColSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop cellMult = agg.getInput().get(0);

    // colsums over b(*)
    boolean colSumsOverMult = agg.getOp() == AggOp.SUM
            && agg.getDirection() == Direction.Col
            && HopRewriteUtils.isBinary(cellMult, OpOp2.MULT);
    if (!colSumsOverMult) {
        return hi;
    }

    Hop matrix = cellMult.getInput().get(0);
    Hop vector = cellMult.getInput().get(1);

    // MV (col vector): true matrix on the left, column vector on the right
    boolean matrixTimesColVector = matrix.getDim1() > 1
            && matrix.getDim2() > 1
            && vector.getDim1() > 1
            && vector.getDim2() == 1;
    if (!matrixTimesColVector) {
        return hi;
    }

    // create new operators
    ReorgOp trans = HopRewriteUtils.createTranspose(vector);
    AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, matrix);

    // relink new child and clean up the replaced aggregate and multiplication
    HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos);
    HopRewriteUtils.cleanupUnreferenced(agg, cellMult);
    LOG.debug("Applied simplifyColSumsMVMult");
    return mmult;
}
Aggregations