use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyScalarMatrixMult.
/**
 * Replaces a matrix multiplication that has a known 1x1 operand by an
 * element-wise multiplication with that operand cast to a scalar:
 * {@code y %*% X -> as.scalar(y) * X} and {@code X %*% y -> as.scalar(y) * X}.
 * The left operand is tested first; the right operand is only considered
 * when the left one does not qualify.
 *
 * @param parent hop whose child reference at {@code pos} is re-pointed on success
 * @param hi     candidate hop (root of the pattern)
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new multiplication hop if the rewrite fired, otherwise {@code hi}
 */
private static Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos) {
    // only X %*% Y is a candidate
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }

    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    // decide which side is the 1x1 operand (left side takes precedence)
    final Hop singleCell;
    final Hop other;
    final String appliedMsg;
    if (HopRewriteUtils.isDimsKnown(lhs) && lhs.getDim1() == 1 && lhs.getDim2() == 1) {
        singleCell = lhs;
        other = rhs;
        appliedMsg = "Applied simplifyScalarMatrixMult1";
    } else if (HopRewriteUtils.isDimsKnown(rhs) && rhs.getDim1() == 1 && rhs.getDim2() == 1) {
        singleCell = rhs;
        other = lhs;
        appliedMsg = "Applied simplifyScalarMatrixMult2";
    } else {
        return hi;
    }

    // the cast is always the first operand of the new multiplication
    UnaryOp asScalar = HopRewriteUtils.createUnary(singleCell, OpOp1.CAST_AS_SCALAR);
    BinaryOp product = HopRewriteUtils.createBinary(asScalar, other, OpOp2.MULT);

    // hang the multiplication under the parent and drop the old matrix multiply
    HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug(appliedMsg);
    return product;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedSigmoidMMChains.
/**
 * Fuses {@code W * f(Y %*% t(X))} chains into a single {@code WSIGMOID}
 * quaternary operator, where {@code f} is one of
 * <ol>
 *   <li>{@code sigmoid(.)}</li>
 *   <li>{@code sigmoid(0 - .)}</li>
 *   <li>{@code log(sigmoid(.))}</li>
 *   <li>{@code log(sigmoid(0 - .))}</li>
 * </ol>
 * The two boolean flags handed to the {@code QuaternaryOp} constructor record
 * whether the log wrapper and the minus wrapper were present, in that order.
 *
 * @param parent hop whose child reference at {@code pos} is re-pointed on success
 * @param hi     candidate hop (root of the pattern)
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new quaternary hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    // all patterns are rooted at W * <unary>
    if (!HopRewriteUtils.isBinary(hi, OpOp2.MULT)
            // not applied for vector-vector mult
            || hi.getDim2() <= 1
            // prevent mv
            || !HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1))
            || hi.getInput().get(0).getDataType() != DataType.MATRIX
            // sigmoid/log
            || !(hi.getInput().get(1) instanceof UnaryOp)) {
        return hi;
    }

    UnaryOp uop = (UnaryOp) hi.getInput().get(1);

    // step 1: locate the argument of sigmoid, peeling an optional outer log
    Hop sigmoidArg = null;
    boolean sigmoidFound = false;
    boolean logOuter = false;
    if (uop.getOp() == OpOp1.SIGMOID) {
        sigmoidArg = uop.getInput().get(0);
        sigmoidFound = true;
    } else if (uop.getOp() == OpOp1.LOG && HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)) {
        sigmoidArg = uop.getInput().get(0).getInput().get(0);
        sigmoidFound = true;
        logOuter = true;
    }

    // step 2: locate the matrix multiply, peeling an optional 0 - (.) wrapper
    Hop matMult = null;
    boolean minusInner = false;
    if (sigmoidFound) {
        if (sigmoidArg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(sigmoidArg.getInput().get(0), true)) {
            matMult = sigmoidArg;
        } else if (HopRewriteUtils.isBinary(sigmoidArg, OpOp2.MINUS)
                && sigmoidArg.getInput().get(0) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) sigmoidArg.getInput().get(0)) == 0
                && sigmoidArg.getInput().get(1) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(sigmoidArg.getInput().get(1).getInput().get(0), true)) {
            matMult = sigmoidArg.getInput().get(1);
            minusInner = true;
        }
    }
    if (matMult == null) {
        return hi;
    }

    // step 3: build the fused operator
    Hop W = hi.getInput().get(0);
    Hop Y = matMult.getInput().get(0);
    Hop tX = matMult.getInput().get(1);
    // strip an existing transpose, otherwise add one
    if (HopRewriteUtils.isTransposeOperation(tX)) {
        tX = tX.getInput().get(0);
    } else {
        tX = HopRewriteUtils.createTranspose(tX);
    }

    Hop hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, tX, logOuter, minusInner);
    hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    hnew.refreshSizeInformation();

    if (!logOuter && !minusInner) {
        LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
    } else if (!logOuter) {
        LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
    } else if (!minusInner) {
        LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
    } else {
        LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
    }

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyDiagMatrixMult.
/**
 * Rewrites {@code diag(X %*% Y)} with a single-column result into
 * {@code rowSums(X * t(Y))}, i.e. a transpose, an element-wise multiply and a
 * row-wise sum replace the matrix multiply followed by diag.
 *
 * @param parent hop whose child reference at {@code pos} is re-pointed on success
 * @param hi     candidate hop (root of the pattern)
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new row-sum hop if the rewrite fired, otherwise {@code hi}
 */
private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
    // root must be a diag whose output has one column (diagM2V)
    boolean diagToVector = hi instanceof ReorgOp
            && ((ReorgOp) hi).getOp() == ReOrgOp.DIAG
            && hi.getDim2() == 1;
    if (!diagToVector) {
        return hi;
    }

    // its input must be X %*% Y
    Hop matMult = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(matMult)) {
        return hi;
    }

    Hop x = matMult.getInput().get(0);
    Hop y = matMult.getInput().get(1);

    // create new operators (incl refresh size inside for transpose)
    ReorgOp yT = HopRewriteUtils.createTranspose(y);
    BinaryOp cellwise = HopRewriteUtils.createBinary(x, yT, OpOp2.MULT);
    AggUnaryOp rowSums = HopRewriteUtils.createAggUnaryOp(cellwise, AggOp.SUM, Direction.Row);

    // rehang the new sub-DAG under the parent and release the old operators
    HopRewriteUtils.replaceChildReference(parent, hi, rowSums, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, matMult);
    LOG.debug("Applied simplifyDiagMatrixMult");
    return rowSums;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method fuseLeftIndexingChainToAppend.
/**
 * Fuses a chain of two left-indexing operations that together overwrite a
 * two-column (or two-row) target into a single append:
 * <ul>
 *   <li>pattern 1: {@code X[,1]=A; X[,2]=B -> X=cbind(A,B)}</li>
 *   <li>pattern 2: {@code X[1,]=A; X[2,]=B -> X=rbind(A,B)}</li>
 * </ul>
 * {@code hi} is the outer (second) left indexing and its first input is the
 * inner (first) one. Both patterns emit a debug message when they fire.
 *
 * @param parent hop whose child reference at {@code pos} is re-pointed on success
 * @param hi     candidate hop (root of the pattern)
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new cbind/rbind hop if a pattern matched, otherwise {@code hi}
 */
private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) {
    boolean applied = false;

    // pattern 1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
    if (// outer lix
    hi instanceof LeftIndexingOp && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp) hi)
            // inner lix
            && hi.getInput().get(0) instanceof LeftIndexingOp && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp) hi.getInput().get(0))
            // inner lix has a single consumer
            && hi.getInput().get(0).getParent().size() == 1
            // first input of the inner lix has exactly two columns
            && hi.getInput().get(0).getInput().get(0).getDim2() == 2) {
        // rhs of the outer lix
        Hop input2 = hi.getInput().get(1);
        // cl=cu of the outer lix
        Hop pred2 = hi.getInput().get(4);
        // rhs of the inner lix
        Hop input1 = hi.getInput().get(0).getInput().get(1);
        // cl=cu of the inner lix
        Hop pred1 = hi.getInput().get(0).getInput().get(4);
        if (pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred1) == 1
                && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred2) == 2
                && input1.getDataType() != DataType.SCALAR && input2.getDataType() != DataType.SCALAR) {
            // create new cbind operation and rewire it under the parent
            BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND);
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            hi = bop;
            applied = true;
            // report the rewrite, mirroring the message of pattern 2
            LOG.debug("Applied fuseLeftIndexingChainToAppend1 (line " + hi.getBeginLine() + ")");
        }
    }

    // pattern 2: X[1,]=A; X[2,]=B -> X=rbind(A,B)
    if (// outer lix (skipped when pattern 1 already fired)
    !applied && hi instanceof LeftIndexingOp && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp) hi)
            // inner lix
            && hi.getInput().get(0) instanceof LeftIndexingOp && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp) hi.getInput().get(0))
            // inner lix has a single consumer
            && hi.getInput().get(0).getParent().size() == 1
            // first input of the inner lix has exactly two rows
            && hi.getInput().get(0).getInput().get(0).getDim1() == 2) {
        // rhs of the outer lix
        Hop input2 = hi.getInput().get(1);
        // rl=ru of the outer lix
        Hop pred2 = hi.getInput().get(2);
        // rhs of the inner lix
        Hop input1 = hi.getInput().get(0).getInput().get(1);
        // rl=ru of the inner lix
        Hop pred1 = hi.getInput().get(0).getInput().get(2);
        if (pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred1) == 1
                && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred2) == 2
                && input1.getDataType() != DataType.SCALAR && input2.getDataType() != DataType.SCALAR) {
            // create new rbind operation and rewire it under the parent
            BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND);
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            hi = bop;
            LOG.debug("Applied fuseLeftIndexingChainToAppend2 (line " + hi.getBeginLine() + ")");
        }
    }
    return hi;
}
use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method removeUnnecessaryMinus.
/**
 * Collapses a double negation on matrices, {@code 0 - (0 - X) -> X}: when
 * {@code hi} and its second input are both matrix-typed binary minus
 * operations whose first input is a literal equal to zero, the parent is
 * re-pointed directly at the innermost operand.
 *
 * @param parent hop whose child reference at {@code pos} is re-pointed on success
 * @param hi     candidate hop (root of the pattern)
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the innermost operand if the rewrite fired, otherwise {@code hi}
 */
private static Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) {
    // first minus: matrix-typed 0 - (.)
    if (hi.getDataType() != DataType.MATRIX
            || !(hi instanceof BinaryOp)
            || ((BinaryOp) hi).getOp() != OpOp2.MINUS
            || !(hi.getInput().get(0) instanceof LiteralOp)
            || ((LiteralOp) hi.getInput().get(0)).getDoubleValue() != 0) {
        return hi;
    }

    // second minus: matrix-typed 0 - (.)
    Hop inner = hi.getInput().get(1);
    if (inner.getDataType() != DataType.MATRIX
            || !(inner instanceof BinaryOp)
            || ((BinaryOp) inner).getOp() != OpOp2.MINUS
            || !(inner.getInput().get(0) instanceof LiteralOp)
            || ((LiteralOp) inner.getInput().get(0)).getDoubleValue() != 0) {
        return hi;
    }

    // remove unnecessary chain of -(-())
    Hop operand = inner.getInput().get(1);
    HopRewriteUtils.replaceChildReference(parent, hi, operand, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, inner);
    LOG.debug("Applied removeUnecessaryMinus");
    return operand;
}
Aggregations