Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method rule_AlgebraicSimplification.
/**
* Applies the dynamic algebraic simplification rewrites to each input of the given
* hop and recurses into the inputs; returns immediately if the hop was already
* visited, and marks it as visited once all inputs have been processed.
*
* For each input position i, the current child hi is passed through a fixed sequence
* of rewrite helpers. Each helper's result is assigned back to hi and fed into the
* next helper, so a rewrite sees the (possibly replaced) output of the rewrites
* before it; the call order below is therefore significant. The fused-operator
* rewrites only run when OptimizerUtils.ALLOW_OPERATOR_FUSION is enabled.
*
* Note: X/y -> X * 1/y would be useful because * is cheaper than / and sparse-safe;
* however, (1) the results would not be exactly the same (two rounding steps instead
* of one) and (2) it should come before constant folding while the other
* simplifications should come after constant folding. Hence, not applied yet.
*
* @param hop high-level operator whose inputs are rewritten
* @param descendFirst true to recursively process children before applying the
*        rewrites (to allow roll-up); false to recurse after the rewrites (to
*        investigate patterns newly created by the rewrites)
*/
private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) {
// skip hops already processed in this pass
if (hop.isVisited())
return;
// recursively process children
for (int i = 0; i < hop.getInput().size(); i++) {
Hop hi = hop.getInput().get(i);
// process children recursively first (to allow roll-up)
if (descendFirst)
// counterpart of the (!descendFirst) recursion at the end of the loop body
rule_AlgebraicSimplification(hi, descendFirst);
// apply actual simplification rewrites (of children, incl. checks)
// e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0
hi = removeEmptyRightIndexing(hop, hi, i);
// e.g., X[,1] -> X, if output == input size
hi = removeUnnecessaryRightIndexing(hop, hi, i);
// e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0
hi = removeEmptyLeftIndexing(hop, hi, i);
// e.g., X[,1]=Y -> Y, if output == input dims
hi = removeUnnecessaryLeftIndexing(hop, hi, i);
if (OptimizerUtils.ALLOW_OPERATOR_FUSION)
// e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix
// (and analogously X[1,]=A; X[2,]=B -> X=rbind(A,B), iff nrow(X)==2)
hi = fuseLeftIndexingChainToAppend(hop, hi, i);
// e.g., cumsum(X) -> X, if nrow(X)==1
hi = removeUnnecessaryCumulativeOp(hop, hi, i);
// e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims
hi = removeUnnecessaryReorgOperation(hop, hi, i);
// e.g., X*(Y%*%matrix(1,...)) -> X*Y, if Y col vector
hi = removeUnnecessaryOuterProduct(hop, hi, i);
// e.g., ifelse(E, A, B) -> A, if E==TRUE or nnz(E)==length(E)
hi = removeUnnecessaryIfElseOperation(hop, hi, i);
if (OptimizerUtils.ALLOW_OPERATOR_FUSION)
// e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
hi = fuseDatagenAndReorgOperation(hop, hi, i);
// e.g., colSums(X) -> sum(X) or X, if col/row vector
hi = simplifyColwiseAggregate(hop, hi, i);
// e.g., rowSums(X) -> sum(X) or X, if row/col vector
hi = simplifyRowwiseAggregate(hop, hi, i);
// e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
hi = simplifyColSumsMVMult(hop, hi, i);
// e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
hi = simplifyRowSumsMVMult(hop, hi, i);
// e.g., sum(X) -> as.scalar(X), if 1x1 dims
hi = simplifyUnnecessaryAggregate(hop, hi, i);
// e.g., sum(X) -> 0, if nnz(X)==0
hi = simplifyEmptyAggregate(hop, hi, i);
// e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
hi = simplifyEmptyUnaryOperation(hop, hi, i);
// e.g., t(X) -> matrix(0, ncol(X), nrow(X))
hi = simplifyEmptyReorgOperation(hop, hi, i);
// e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0
hi = simplifyEmptySortOperation(hop, hi, i);
// e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1)
hi = simplifyEmptyMatrixMult(hop, hi, i);
// e.g., X%*%y -> X if y matrix(1,1,1)
hi = simplifyIdentityRepMatrixMult(hop, hi, i);
// e.g., X%*%y -> X*as.scalar(y), if y is a 1-1 matrix
hi = simplifyScalarMatrixMult(hop, hi, i);
// e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1
hi = simplifyMatrixMultDiag(hop, hi, i);
// e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
hi = simplifyDiagMatrixMult(hop, hi, i);
// e.g., sum(diag(X)) -> trace(X); if col vector
// NOTE(review): unlike most helpers, this one receives no parent/position;
// presumably it rewrites hi in place rather than rehanging it - confirm
hi = simplifySumDiagToTrace(hi);
// e.g., diag(X)*7 -> diag(X*7); if col vector
hi = pushdownBinaryOperationOnDiag(hop, hi, i);
// e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
hi = pushdownSumOnAdditiveBinary(hop, hi, i);
if (OptimizerUtils.ALLOW_OPERATOR_FUSION) {
// e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true)
hi = simplifyWeightedSquaredLoss(hop, hi, i);
// e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
hi = simplifyWeightedSigmoidMMChains(hop, hi, i);
// e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
hi = simplifyWeightedDivMM(hop, hi, i);
// e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
hi = simplifyWeightedCrossEntropy(hop, hi, i);
// e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
hi = simplifyWeightedUnaryMM(hop, hi, i);
// e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1
hi = simplifyDotProductSum(hop, hi, i);
// e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
hi = fuseSumSquared(hop, hi, i);
// e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)
hi = fuseAxpyBinaryOperationChain(hop, hi, i);
}
// e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
hi = reorderMinusMatrixMult(hop, hi, i);
// e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
hi = simplifySumMatrixMult(hop, hi, i);
// e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
hi = simplifyEmptyBinaryOperation(hop, hi, i);
// e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix
// NOTE(review): like simplifySumDiagToTrace, this helper receives no parent/position - confirm in-place semantics
hi = simplifyScalarMVBinaryOperation(hi);
// e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
hi = simplifyNnzComputation(hop, hi, i);
// e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency
hi = simplifyNrowNcolComputation(hop, hi, i);
// e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
hi = simplifyTableSeqExpand(hop, hi, i);
// process children recursively after rewrites (to investigate patterns newly created by rewrites)
if (!descendFirst)
rule_AlgebraicSimplification(hi, descendFirst);
}
hop.setVisited();
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyDotProductSum.
/**
 * Rewrites a full sum over a column vector into a dot product, which avoids
 * materializing the element-wise intermediate:
 * <ul>
 *   <li>{@code sum(v^2) -> as.scalar(t(v) %*% v)}, if the square has no consumer other
 *       than the sum (the square may itself have been rewritten from sum(v*v))</li>
 *   <li>{@code sum(v1*v2) -> as.scalar(t(v1) %*% v2)}, if the product has a single
 *       consumer, v1 and v2 are column vectors, neither operand is itself a
 *       multiplication (sum(v1*v2*v3) is later compiled into a ta+* lop) and, when
 *       ALLOW_SUM_PRODUCT_REWRITES is set, neither operand is a square A^2 (left to
 *       tak+*)</li>
 * </ul>
 *
 * <p>NOTE: dot-product-sum could be also applied to sum(a*b) of general matrices.
 * However, we restrict ourselves to column vectors since a general mm a%*%b on MR can
 * be counter-productive (e.g., MMCJ) while tsmm is always beneficial.
 *
 * @param parent parent high-level operator
 * @param hi     candidate sum aggregate, input {@code pos} of {@code parent}
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new cast-as-scalar operator if the rewrite was applied, otherwise {@code hi}
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // full sum aggregate over a column vector (vector restriction needed for correctness)
    if (hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && hi.getInput().get(0).getDim2() == 1) {
        Hop baLeft = null;
        Hop baRight = null;
        Hop hi2 = hi.getInput().get(0);

        if (isPowOfLiteralTwo(hi2) && hi2.getParent().size() == 1) {
            // sum(v^2) with the sum as the only consumer of v^2
            Hop input = hi2.getInput().get(0);
            baLeft = input;
            baRight = input;
        } else if (HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) // no other consumer than sum
            && hi2.getInput().get(0).getDim2() == 1
            && hi2.getInput().get(1).getDim2() == 1
            // prevent rewriting sum(v1*v2*v3), which is later compiled into a ta+* lop
            && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
            && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
            // do not rewrite (A^2)*B or B*(A^2); let tak+* handle them
            && (!ALLOW_SUM_PRODUCT_REWRITES
                || (!isPowOfLiteralTwo(hi2.getInput().get(0))
                    && !isPowOfLiteralTwo(hi2.getInput().get(1))))) {
            baLeft = hi2.getInput().get(0);
            baRight = hi2.getInput().get(1);
        }

        // perform actual rewrite (if necessary)
        if (baLeft != null && baRight != null) {
            // create new operator chain: as.scalar(t(baLeft) %*% baRight)
            ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
            AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
            UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);
            // rehang new subdag under parent node and drop the now-unused sum and its input
            HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, hi2);
            hi = cast;
            LOG.debug("Applied simplifyDotProductSum.");
        }
    }
    return hi;
}

/**
 * Checks whether the given hop is a square {@code A^2}, i.e., a binary POW whose
 * exponent is a literal equal to 2. The exponent is compared exactly (as a double),
 * so the {@code sum(v^2)} match and the tak+* exclusion use the same criterion.
 *
 * @param hop high-level operator
 * @return true if {@code hop} is a power operation with literal exponent 2
 */
private static boolean isPowOfLiteralTwo(Hop hop) {
    return HopRewriteUtils.isBinary(hop, OpOp2.POW)
        && hop.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) hop.getInput().get(1)) == 2;
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyDiagMatrixMult.
/**
 * Rewrites the diagonal of a matrix product into the row sums of an element-wise
 * product: {@code diag(X%*%Y) -> rowSums(X*t(Y))}. Only diag operations whose
 * output has a single column qualify.
 *
 * @param parent parent high-level operator
 * @param hi     candidate diag operator, input {@code pos} of {@code parent}
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new row-sum aggregate if the rewrite was applied, otherwise {@code hi}
 */
private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
    boolean isDiagToVector = hi instanceof ReorgOp
        && ((ReorgOp) hi).getOp() == ReOrgOp.DIAG
        && hi.getDim2() == 1;
    if (!isDiagToVector) {
        return hi;
    }

    Hop product = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(product)) {
        return hi;
    }

    Hop lhs = product.getInput().get(0);
    Hop rhs = product.getInput().get(1);

    // diag(X%*%Y)[i] = sum_k X[i,k] * Y[k,i] = rowSums(X * t(Y))[i]
    ReorgOp rhsTransposed = HopRewriteUtils.createTranspose(rhs);
    BinaryOp elementwise = HopRewriteUtils.createBinary(lhs, rhsTransposed, OpOp2.MULT);
    AggUnaryOp rowTotals = HopRewriteUtils.createAggUnaryOp(elementwise, AggOp.SUM, Direction.Row);

    // rehang the new subdag under the parent and drop the now-unused diag/product
    HopRewriteUtils.replaceChildReference(parent, hi, rowTotals, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);
    LOG.debug("Applied simplifyDiagMatrixMult");
    return rowTotals;
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method fuseLeftIndexingChainToAppend.
/**
 * Fuses a chain of two left-indexing operations that overwrite both columns (or both
 * rows) of a two-column (or two-row) matrix/frame into a single append:
 * <ul>
 *   <li>pattern 1: {@code X[,1]=A; X[,2]=B -> X=cbind(A,B)}, if ncol(X)==2</li>
 *   <li>pattern 2: {@code X[1,]=A; X[2,]=B -> X=rbind(A,B)}, if nrow(X)==2</li>
 * </ul>
 * Both left-indexing operations must be full column (pattern 1) or full row
 * (pattern 2) indexing according to {@link HopRewriteUtils}, the inner one must have
 * the outer one as its only consumer, the index literals must be 1 (inner) and 2
 * (outer), and both right-hand sides must be non-scalar.
 *
 * @param parent parent high-level operator
 * @param hi     candidate (outer) left-indexing operator, input {@code pos} of {@code parent}
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new cbind/rbind operator if a pattern was applied, otherwise {@code hi}
 */
private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) {
    boolean applied = false;

    // pattern 1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
    if (hi instanceof LeftIndexingOp // outer lix
        && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp) hi)
        && hi.getInput().get(0) instanceof LeftIndexingOp // inner lix
        && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp) hi.getInput().get(0))
        && hi.getInput().get(0).getParent().size() == 1 // inner lix is single consumer
        && hi.getInput().get(0).getInput().get(0).getDim2() == 2) { // two-column matrix
        Hop input2 = hi.getInput().get(1); // rhs matrix of outer lix
        Hop pred2 = hi.getInput().get(4); // column index of outer lix (cl=cu)
        Hop input1 = hi.getInput().get(0).getInput().get(1); // rhs matrix of inner lix
        Hop pred1 = hi.getInput().get(0).getInput().get(4); // column index of inner lix (cl=cu)
        if (pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred1) == 1
            && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred2) == 2
            && input1.getDataType() != DataType.SCALAR && input2.getDataType() != DataType.SCALAR) {
            // create new cbind operation and rewrite inputs
            BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND);
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            hi = bop;
            applied = true;
            LOG.debug("Applied fuseLeftIndexingChainToAppend1 (line " + hi.getBeginLine() + ")");
        }
    }

    // pattern 2: X[1,]=A; X[2,]=B -> X=rbind(A,B); matrix / frame
    if (!applied
        && hi instanceof LeftIndexingOp // outer lix
        && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp) hi)
        && hi.getInput().get(0) instanceof LeftIndexingOp // inner lix
        && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp) hi.getInput().get(0))
        && hi.getInput().get(0).getParent().size() == 1 // inner lix is single consumer
        && hi.getInput().get(0).getInput().get(0).getDim1() == 2) { // two-row matrix
        Hop input2 = hi.getInput().get(1); // rhs matrix of outer lix
        Hop pred2 = hi.getInput().get(2); // row index of outer lix (rl=ru)
        Hop input1 = hi.getInput().get(0).getInput().get(1); // rhs matrix of inner lix
        Hop pred1 = hi.getInput().get(0).getInput().get(2); // row index of inner lix (rl=ru)
        if (pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred1) == 1
            && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred2) == 2
            && input1.getDataType() != DataType.SCALAR && input2.getDataType() != DataType.SCALAR) {
            // create new rbind operation and rewrite inputs
            BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND);
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            hi = bop;
            applied = true;
            LOG.debug("Applied fuseLeftIndexingChainToAppend2 (line " + hi.getBeginLine() + ")");
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyNrowNcolComputation.
/**
 * Replaces {@code nrow(X)} / {@code ncol(X)} by a literal holding the respective
 * dimension of X whenever that dimension is known, so the result no longer depends
 * on X, even if the intermediate is otherwise not required (e.g., when part of a
 * fused operator).
 *
 * @param parent parent high-level operator
 * @param hi     candidate nrow/ncol operator, input {@code pos} of {@code parent}
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new literal if the rewrite was applied, otherwise {@code hi}
 */
private static Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp)) {
        return hi;
    }

    OpOp1 op = ((UnaryOp) hi).getOp();
    boolean isNrow = op == OpOp1.NROW;
    if (!isNrow && op != OpOp1.NCOL) {
        return hi;
    }

    Hop input = hi.getInput().get(0);
    boolean dimKnown = isNrow ? input.rowsKnown() : input.colsKnown();
    if (!dimKnown) {
        return hi;
    }

    // rehang the literal without refreshing sizes, then drop the unused nrow/ncol op
    Hop literal = new LiteralOp(isNrow ? input.getDim1() : input.getDim2());
    HopRewriteUtils.replaceChildReference(parent, hi, literal, pos, false);
    HopRewriteUtils.cleanupUnreferenced(hi);
    String prefix = isNrow ? "Applied simplifyNrowComputation nrow(" : "Applied simplifyNcolComputation ncol(";
    LOG.debug(prefix + hi.getHopID() + ") -> " + literal.getName() + " (line " + hi.getBeginLine() + ").");
    return literal;
}
Aggregations