Search in sources:

Example 86 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

The class RewriteAlgebraicSimplificationDynamic, method rule_AlgebraicSimplification.

/**
 * Applies the dynamic algebraic simplification rewrites to the inputs of the
 * given operator, and recursively to the sub-DAG below it.
 *
 * For every input position i of {@code hop}, the current input is passed through
 * a fixed sequence of rewrite methods. Each rewrite receives the parent
 * ({@code hop}), the current input and i. Its return value becomes the input seen
 * by the next rewrite, so a pattern produced by one rewrite can be matched by a
 * later one in the same pass. Because of this chaining, the order of the calls
 * matters; do not reorder them without checking rule interactions. Rewrites
 * guarded by {@code OptimizerUtils.ALLOW_OPERATOR_FUSION} are only attempted when
 * that flag is set. Operators already marked visited are skipped, and
 * {@code hop} is marked visited after all of its inputs have been processed.
 *
 * Note: X/y -> X * 1/y would be useful because * is cheaper than / and sparse-safe.
 * However, (1) the results would not be exactly the same (2 roundings instead of 1),
 * and (2) it would have to come before constant folding, while the other
 * simplifications should come after constant folding. Hence, it is not applied yet.
 *
 * @param hop high-level operator whose inputs are simplified
 * @param descendFirst if true, each input's sub-DAG is simplified before the
 *        rewrites are applied to that input (allows roll-up of already simplified
 *        children); if false, the rewrites are applied first and the recursion then
 *        descends into the (possibly replaced) input, so patterns newly created by
 *        the rewrites are examined as well
 */
private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) {
    // skip operators already processed (the visit flag is set at the end of this method)
    if (hop.isVisited())
        return;
    // recursively process children
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop hi = hop.getInput().get(i);
        // process children recursively first (to allow roll-up)
        if (descendFirst)
            // simplify the child's sub-DAG before rewriting the child itself
            rule_AlgebraicSimplification(hi, descendFirst);
        // apply the actual simplification rewrites (to the children, incl. checks);
        // each rewrite returns the (possibly replaced) child, which feeds the next one
        // e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0
        hi = removeEmptyRightIndexing(hop, hi, i);
        // e.g., X[,1] -> X, if output == input size
        hi = removeUnnecessaryRightIndexing(hop, hi, i);
        // e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0
        hi = removeEmptyLeftIndexing(hop, hi, i);
        // e.g., X[,1]=Y -> Y, if output == input dims
        hi = removeUnnecessaryLeftIndexing(hop, hi, i);
        if (OptimizerUtils.ALLOW_OPERATOR_FUSION)
            // e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix
            hi = fuseLeftIndexingChainToAppend(hop, hi, i);
        // e.g., cumsum(X) -> X, if nrow(X)==1
        hi = removeUnnecessaryCumulativeOp(hop, hi, i);
        // e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims
        hi = removeUnnecessaryReorgOperation(hop, hi, i);
        // e.g., X*(Y%*%matrix(1,...)) -> X*Y, if Y col vector
        hi = removeUnnecessaryOuterProduct(hop, hi, i);
        // e.g., ifelse(E, A, B) -> A, if E==TRUE or nnz(E)==length(E)
        hi = removeUnnecessaryIfElseOperation(hop, hi, i);
        if (OptimizerUtils.ALLOW_OPERATOR_FUSION)
            // e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
            hi = fuseDatagenAndReorgOperation(hop, hi, i);
        // e.g., colSums(X) -> sum(X) or X, if col/row vector
        hi = simplifyColwiseAggregate(hop, hi, i);
        // e.g., rowSums(X) -> sum(X) or X, if row/col vector
        hi = simplifyRowwiseAggregate(hop, hi, i);
        // e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
        hi = simplifyColSumsMVMult(hop, hi, i);
        // e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
        hi = simplifyRowSumsMVMult(hop, hi, i);
        // e.g., sum(X) -> as.scalar(X), if 1x1 dims
        hi = simplifyUnnecessaryAggregate(hop, hi, i);
        // e.g., sum(X) -> 0, if nnz(X)==0
        hi = simplifyEmptyAggregate(hop, hi, i);
        // e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
        hi = simplifyEmptyUnaryOperation(hop, hi, i);
        // e.g., t(X) -> matrix(0, ncol(X), nrow(X)), presumably if nnz(X)==0 (empty input)
        hi = simplifyEmptyReorgOperation(hop, hi, i);
        // e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0
        hi = simplifyEmptySortOperation(hop, hi, i);
        // e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1)
        hi = simplifyEmptyMatrixMult(hop, hi, i);
        // e.g., X%*%y -> X, if y is matrix(1,1,1)
        hi = simplifyIdentityRepMatrixMult(hop, hi, i);
        // e.g., X%*%y -> X*as.scalar(y), if y is a 1x1 matrix
        hi = simplifyScalarMatrixMult(hop, hi, i);
        // e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1
        hi = simplifyMatrixMultDiag(hop, hi, i);
        // e.g., diag(X%*%Y) -> rowSums(X*t(Y)), if col vector
        hi = simplifyDiagMatrixMult(hop, hi, i);
        // e.g., sum(diag(X)) -> trace(X), if col vector
        // NOTE(review): unlike most rewrites here, this one receives no parent/position;
        // presumably it modifies hi in place — confirm
        hi = simplifySumDiagToTrace(hi);
        // e.g., diag(X)*7 -> diag(X*7), if col vector
        hi = pushdownBinaryOperationOnDiag(hop, hi, i);
        // e.g., sum(A+B) -> sum(A)+sum(B), if dims(A)==dims(B)
        hi = pushdownSumOnAdditiveBinary(hop, hi, i);
        if (OptimizerUtils.ALLOW_OPERATOR_FUSION) {
            // e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true)
            hi = simplifyWeightedSquaredLoss(hop, hi, i);
            // e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
            hi = simplifyWeightedSigmoidMMChains(hop, hi, i);
            // e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
            hi = simplifyWeightedDivMM(hop, hi, i);
            // e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
            hi = simplifyWeightedCrossEntropy(hop, hi, i);
            // e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
            hi = simplifyWeightedUnaryMM(hop, hi, i);
            // e.g., sum(v^2) -> t(v)%*%v, if ncol(v)==1
            hi = simplifyDotProductSum(hop, hi, i);
            // e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
            hi = fuseSumSquared(hop, hi, i);
            // e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)
            hi = fuseAxpyBinaryOperationChain(hop, hi, i);
        }
        // e.g., (-t(X))%*%y -> -(t(X)%*%y), TODO size
        hi = reorderMinusMatrixMult(hop, hi, i);
        // e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
        hi = simplifySumMatrixMult(hop, hi, i);
        // e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
        hi = simplifyEmptyBinaryOperation(hop, hi, i);
        // e.g., X*y -> X*as.scalar(y), if y is a 1x1 matrix
        // NOTE(review): also receives no parent/position; presumably rewrites hi in place — confirm
        hi = simplifyScalarMVBinaryOperation(hi);
        // e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
        hi = simplifyNnzComputation(hop, hi, i);
        // e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency
        hi = simplifyNrowNcolComputation(hop, hi, i);
        // e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
        hi = simplifyTableSeqExpand(hop, hi, i);
        // process children recursively after the rewrites (to investigate patterns newly created by rewrites)
        if (!descendFirst)
            rule_AlgebraicSimplification(hi, descendFirst);
    }
    hop.setVisited();
}
Also used : Hop(org.apache.sysml.hops.Hop)

Example 87 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyDotProductSum.

/**
 * Rewrites a full sum over a column vector of the form sum(v^2) or sum(v1*v2)
 * into as.scalar(t(v1) %*% v2). The dot product is then computed by a matrix
 * multiply instead of materializing the element-wise intermediate.
 *
 * Historical note: dot-product-sum could also be applied to sum(a*b) in general,
 * but a general mm a%*%b on MR can be counter-productive (e.g., MMCJ), while tsmm
 * is always beneficial. NOTE(review): the code below does rewrite sum(v1*v2), but
 * only when both factors are column vectors; the restriction seems to apply to the
 * general (matrix) case only.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate full sum)
 * @param pos position of hi among the inputs of parent
 * @return the new as.scalar operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // guard: full (row and column) sum aggregate over a column vector (for correctness)
    boolean fullSumOverColVector = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && hi.getInput().get(0).getDim2() == 1;
    if (!fullSumOverColVector)
        return hi;

    Hop inner = hi.getInput().get(0);
    Hop lhs = null;
    Hop rhs = null;
    if (HopRewriteUtils.isBinary(inner, OpOp2.POW)
        && inner.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) inner.getInput().get(1)) == 2
        && inner.getParent().size() == 1) {
        // sum(v^2) with the sum as the only consumer of v^2
        // (v^2 may itself stem from an earlier rewrite of sum(v*v))
        lhs = inner.getInput().get(0);
        rhs = lhs;
    } else if (isDotProductCandidate(inner)) {
        // sum(v1*v2)
        lhs = inner.getInput().get(0);
        rhs = inner.getInput().get(1);
    }
    if (lhs == null || rhs == null)
        return hi;

    // build as.scalar(t(lhs) %*% rhs)
    ReorgOp transposed = HopRewriteUtils.createTranspose(lhs);
    AggBinaryOp dot = HopRewriteUtils.createMatrixMultiply(transposed, rhs);
    UnaryOp scalar = HopRewriteUtils.createUnary(dot, OpOp1.CAST_AS_SCALAR);
    // rehang the new sub-DAG under the parent node
    HopRewriteUtils.replaceChildReference(parent, hi, scalar, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, inner);
    LOG.debug("Applied simplifyDotProductSum.");
    return scalar;
}

/**
 * Checks for v1*v2 where the product has at most one consumer, both factors are
 * column vectors and neither factor is itself a product. sum(v1*v2*v3) is left
 * for the later ta+* lop. If ALLOW_SUM_PRODUCT_REWRITES is set, neither factor
 * may be X^2, so that tak+* can handle it.
 */
private static boolean isDotProductCandidate(Hop mult) {
    if (!HopRewriteUtils.isBinary(mult, OpOp2.MULT, 1))
        return false;
    Hop first = mult.getInput().get(0);
    Hop second = mult.getInput().get(1);
    return first.getDim2() == 1
        && second.getDim2() == 1
        && !HopRewriteUtils.isBinary(first, OpOp2.MULT)
        && !HopRewriteUtils.isBinary(second, OpOp2.MULT)
        && (!ALLOW_SUM_PRODUCT_REWRITES || !isPowOfLongLiteralTwo(first))
        && (!ALLOW_SUM_PRODUCT_REWRITES || !isPowOfLongLiteralTwo(second));
}

/** Checks for X^e where e is a literal whose long value equals 2. */
private static boolean isPowOfLongLiteralTwo(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.POW)
        && h.getInput().get(1) instanceof LiteralOp
        && ((LiteralOp) h.getInput().get(1)).getLongValue() == 2;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 88 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyDiagMatrixMult.

/**
 * Rewrites diag(X %*% Y) -> rowSums(X * t(Y)) when diag extracts the diagonal of
 * the product into a column vector. Since diag(X %*% Y)[i] = sum_k X[i,k] * Y[k,i],
 * the full matrix product never has to be computed.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate diag)
 * @param pos position of hi among the inputs of parent
 * @return the new rowSums operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
    // guard: matrix-to-vector diag (result has a single column)
    boolean isDiagToVector = hi instanceof ReorgOp
        && ((ReorgOp) hi).getOp() == ReOrgOp.DIAG
        && hi.getDim2() == 1;
    if (!isDiagToVector)
        return hi;
    Hop product = hi.getInput().get(0);
    // guard: the diag input is X %*% Y
    if (!HopRewriteUtils.isMatrixMultiply(product))
        return hi;

    Hop lhs = product.getInput().get(0);
    Hop rhs = product.getInput().get(1);
    // build rowSums(X * t(Y)); sizes of the transpose are refreshed on creation
    ReorgOp transposedRhs = HopRewriteUtils.createTranspose(rhs);
    BinaryOp elementwise = HopRewriteUtils.createBinary(lhs, transposedRhs, OpOp2.MULT);
    AggUnaryOp rowSums = HopRewriteUtils.createAggUnaryOp(elementwise, AggOp.SUM, Direction.Row);
    // rehang the new sub-DAG under the parent node
    HopRewriteUtils.replaceChildReference(parent, hi, rowSums, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);
    LOG.debug("Applied simplifyDiagMatrixMult");
    return rowSums;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 89 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

The class RewriteAlgebraicSimplificationDynamic, method fuseLeftIndexingChainToAppend.

/**
 * Fuses a chain of two left-indexing operations that completely overwrite a
 * two-column (or two-row) target into a single append:
 * <ul>
 *   <li>pattern 1: X[,1]=A; X[,2]=B -> X=cbind(A,B) (matrix / frame), iff ncol(X)==2</li>
 *   <li>pattern 2: X[1,]=A; X[2,]=B -> X=rbind(A,B), iff nrow(X)==2</li>
 * </ul>
 * {@code hi} is the outer left-indexing (the write of B). Its input 0 is the inner
 * left-indexing (the write of A), which must have no other consumer. Both A and B
 * must be non-scalar, and the written column (row) indexes must be the literals 1
 * and 2. At most one pattern is applied per call.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate outer left-indexing)
 * @param pos position of hi among the inputs of parent
 * @return the new cbind/rbind operator if a pattern applied, otherwise hi
 */
private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) {
    boolean applied = false;
    // pattern 1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
    if (hi instanceof LeftIndexingOp // outer lix (X[,2]=B)
        && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp) hi)
        && hi.getInput().get(0) instanceof LeftIndexingOp // inner lix (X[,1]=A)
        && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp) hi.getInput().get(0))
        && hi.getInput().get(0).getParent().size() == 1 // inner lix has a single consumer
        && hi.getInput().get(0).getInput().get(0).getDim2() == 2) { // two-column target X
        Hop innerLix = hi.getInput().get(0);
        // rhs matrix B of the outer lix and its column index (cl=cu)
        Hop input2 = hi.getInput().get(1);
        Hop pred2 = hi.getInput().get(4);
        // rhs matrix A of the inner lix and its column index (cl=cu)
        Hop input1 = innerLix.getInput().get(1);
        Hop pred1 = innerLix.getInput().get(4);
        if (pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred1) == 1
            && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred2) == 2
            && input1.getDataType() != DataType.SCALAR && input2.getDataType() != DataType.SCALAR) {
            // create new cbind operation and rewrite inputs
            BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND);
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            // release references held by the replaced lix chain (as the other rewrites do),
            // so A, B and X do not keep stale parents
            HopRewriteUtils.cleanupUnreferenced(hi, innerLix);
            hi = bop;
            applied = true;
            LOG.debug("Applied fuseLeftIndexingChainToAppend1 (line " + hi.getBeginLine() + ")");
        }
    }
    // pattern 2: X[1,]=A; X[2,]=B -> X=rbind(A,B)
    if (!applied
        && hi instanceof LeftIndexingOp // outer lix (X[2,]=B)
        && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp) hi)
        && hi.getInput().get(0) instanceof LeftIndexingOp // inner lix (X[1,]=A)
        && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp) hi.getInput().get(0))
        && hi.getInput().get(0).getParent().size() == 1 // inner lix has a single consumer
        && hi.getInput().get(0).getInput().get(0).getDim1() == 2) { // two-row target X
        Hop innerLix = hi.getInput().get(0);
        // rhs matrix B of the outer lix and its row index (rl=ru)
        Hop input2 = hi.getInput().get(1);
        Hop pred2 = hi.getInput().get(2);
        // rhs matrix A of the inner lix and its row index (rl=ru)
        Hop input1 = innerLix.getInput().get(1);
        Hop pred1 = innerLix.getInput().get(2);
        if (pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred1) == 1
            && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred2) == 2
            && input1.getDataType() != DataType.SCALAR && input2.getDataType() != DataType.SCALAR) {
            // create new rbind operation and rewrite inputs
            BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND);
            HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
            // release references held by the replaced lix chain (as the other rewrites do)
            HopRewriteUtils.cleanupUnreferenced(hi, innerLix);
            hi = bop;
            LOG.debug("Applied fuseLeftIndexingChainToAppend2 (line " + hi.getBeginLine() + ")");
        }
    }
    return hi;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 90 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyNrowNcolComputation.

/**
 * Replaces nrow(X) or ncol(X) by a literal when the corresponding dimension of X
 * is known, which removes the data dependency on X. The replacement is made even
 * if the intermediate is otherwise not required (e.g., when it is part of a fused
 * operator).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate nrow/ncol)
 * @param pos position of hi among the inputs of parent
 * @return the new literal if a rewrite applied, otherwise hi
 */
private static Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp))
        return hi;
    OpOp1 kind = ((UnaryOp) hi).getOp();
    Hop literal;
    if (kind == OpOp1.NROW && hi.getInput().get(0).rowsKnown()) {
        literal = new LiteralOp(hi.getInput().get(0).getDim1());
        HopRewriteUtils.replaceChildReference(parent, hi, literal, pos, false);
        HopRewriteUtils.cleanupUnreferenced(hi);
        LOG.debug("Applied simplifyNrowComputation nrow(" + hi.getHopID() + ") -> " + literal.getName() + " (line " + hi.getBeginLine() + ").");
    } else if (kind == OpOp1.NCOL && hi.getInput().get(0).colsKnown()) {
        literal = new LiteralOp(hi.getInput().get(0).getDim2());
        HopRewriteUtils.replaceChildReference(parent, hi, literal, pos, false);
        HopRewriteUtils.cleanupUnreferenced(hi);
        LOG.debug("Applied simplifyNcolComputation ncol(" + hi.getHopID() + ") -> " + literal.getName() + " (line " + hi.getBeginLine() + ").");
    } else {
        // neither nrow nor ncol, or the dimension is unknown: leave untouched
        return hi;
    }
    return literal;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Aggregations

Hop (org.apache.sysml.hops.Hop)307 LiteralOp (org.apache.sysml.hops.LiteralOp)94 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)65 BinaryOp (org.apache.sysml.hops.BinaryOp)63 ArrayList (java.util.ArrayList)61 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)61 HashMap (java.util.HashMap)44 DataOp (org.apache.sysml.hops.DataOp)41 UnaryOp (org.apache.sysml.hops.UnaryOp)41 HashSet (java.util.HashSet)39 ReorgOp (org.apache.sysml.hops.ReorgOp)32 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)28 StatementBlock (org.apache.sysml.parser.StatementBlock)28 IndexingOp (org.apache.sysml.hops.IndexingOp)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)23 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)23 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 DataGenOp (org.apache.sysml.hops.DataGenOp)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HopsException (org.apache.sysml.hops.HopsException)18