Search in sources :

Example 26 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method fuseAxpyBinaryOperationChain.

/**
 * Fuses a matrix addition/subtraction whose other operand is a scalar-matrix
 * product into a single ternary operator. Covered patterns:
 * (a) X + s*Y -> X +* sY, (b) s*Y + X -> X +* sY, (c) X - s*Y -> X -* sY.
 * The product s*Y must have exactly one parent, and both operands of the
 * outer operation must pass HopRewriteUtils.isEqualSize.
 *
 * @param parent parent high-level operator whose child reference is rewired
 * @param hi high-level operator inspected as root of the pattern
 * @param pos position of hi within the inputs of parent
 * @return the replacement operator if a pattern applied, otherwise hi
 */
private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) {
    // guard clauses: only non-outer binary plus/minus operations qualify
    if (!(hi instanceof BinaryOp))
        return hi;
    BinaryOp bop = (BinaryOp) hi;
    if (bop.isOuterVectorOperator())
        return hi;
    boolean plus = bop.getOp() == OpOp2.PLUS;
    boolean minus = bop.getOp() == OpOp2.MINUS;
    if (!plus && !minus)
        return hi;

    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);
    Hop fused = null;

    if (plus && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(right)
        && HopRewriteUtils.isEqualSize(left, right)
        && right.getParent().size() == 1) {
        // pattern (a) X + s*Y -> X +* sY
        fused = buildAxpyOrPassThrough(left, right, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " + hi.getBeginLine() + ")");
    }
    else if (plus && right.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(left)
        && HopRewriteUtils.isEqualSize(left, right)
        && left.getParent().size() == 1) {
        // pattern (b) s*Y + X -> X +* sY
        fused = buildAxpyOrPassThrough(right, left, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " + hi.getBeginLine() + ")");
    }
    else if (minus && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(right)
        && HopRewriteUtils.isEqualSize(left, right)
        && right.getParent().size() == 1) {
        // pattern (c) X - s*Y -> X -* sY
        fused = buildAxpyOrPassThrough(left, right, OpOp3.MINUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " + hi.getBeginLine() + ")");
    }

    // nothing matched: leave the dag untouched
    if (fused == null)
        return hi;

    // rewire parent-child operators
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}

/**
 * Builds the fused operator for X (+|-) s*Y, or hands back X unchanged when
 * s is a literal whose value (read via getDoubleValueSafe) equals 0.
 *
 * @param addend the matrix operand X
 * @param product the scalar-matrix multiplication s*Y (scalar on either side)
 * @param fusedOp ternary operation type to create
 * @return addend itself or a newly created ternary operator
 */
private static Hop buildAxpyOrPassThrough(Hop addend, Hop product, OpOp3 fusedOp) {
    // locate which of the two product inputs is the scalar
    int scalarIx = (product.getInput().get(0).getDataType() == DataType.SCALAR) ? 0 : 1;
    Hop scalar = product.getInput().get(scalarIx);
    Hop matrix = product.getInput().get(1 - scalarIx);
    if (scalar instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) scalar) == 0)
        return addend;
    return HopRewriteUtils.createTernaryOp(addend, scalar, matrix, fusedOp);
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 27 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyEmptyReorgOperation.

/**
 * Replaces a reorg operation (transpose, rev, diag, reshape) over an input
 * that HopRewriteUtils.isEmpty reports as empty with a data generator created
 * with value 0, so the reorg itself no longer needs to be executed.
 *
 * @param parent parent high-level operator whose child reference is rewired
 * @param hi high-level operator inspected for the pattern
 * @param pos position of hi within the inputs of parent
 * @return the replacement operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyEmptyReorgOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof ReorgOp))
        return hi;

    ReorgOp reorg = (ReorgOp) hi;
    Hop in = reorg.getInput().get(0);
    if (!HopRewriteUtils.isEmpty(in))
        return hi;

    // pick the replacement depending on the concrete reorg operation
    ReOrgOp op = reorg.getOp();
    Hop replacement = null;
    if (op == ReOrgOp.TRANSPOSE) {
        // presumably rows/cols are both taken from the input with swapped roles - TODO confirm
        replacement = HopRewriteUtils.createDataGenOp(in, true, in, true, 0);
    }
    else if (op == ReOrgOp.REV) {
        replacement = HopRewriteUtils.createDataGenOp(in, 0);
    }
    else if (op == ReOrgOp.DIAG) {
        // the output shape depends on the input shape, so dims must be known
        if (HopRewriteUtils.isDimsKnown(in)) {
            if (in.getDim2() == 1) {
                // diagv2m
                replacement = HopRewriteUtils.createDataGenOp(in, false, in, true, 0);
            }
            else {
                // diagm2v
                Hop rows = HopRewriteUtils.createValueHop(in, true);
                replacement = HopRewriteUtils.createDataGenOpByVal(rows, new LiteralOp(1), 0);
            }
        }
    }
    else if (op == ReOrgOp.RESHAPE) {
        // target dimensions are given by the 2nd and 3rd inputs of the reshape
        replacement = HopRewriteUtils.createDataGenOpByVal(reorg.getInput().get(1), reorg.getInput().get(2), 0);
    }

    // no rule applied (other op type, or diag with unknown dims)
    if (replacement == null)
        return hi;

    // modify dag
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    LOG.debug("Applied simplifyEmptyReorgOperation");
    return replacement;
}
Also used : ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 28 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedSigmoidMMChains.

/**
 * Rewrites chains of the form W * f(Y %*% t(X)) into a single WSIGMOID
 * quaternary operator, where f is one of:
 * 1) sigmoid(.), 2) sigmoid(-(.)), 3) log(sigmoid(.)), 4) log(sigmoid(-(.))).
 * The rewrite is only attempted for an element-wise multiplication with more
 * than one column whose two inputs pass isEqualSize, whose left input is a
 * matrix, and whose right input is a unary operation.
 *
 * @param parent parent high-level operator whose child reference is rewired
 * @param hi high-level operator inspected as root of the pattern
 * @param pos position of hi within the inputs of parent
 * @return the new quaternary operator if a pattern applied, otherwise hi
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    // all patterns are subrooted by W * ...
    boolean rootMatches = HopRewriteUtils.isBinary(hi, OpOp2.MULT)
        // not applied for vector-vector mult
        && hi.getDim2() > 1
        // prevent mv
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1))
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        // sigmoid/log
        && hi.getInput().get(1) instanceof UnaryOp;

    Hop hnew = null;
    if (rootMatches) {
        Hop W = hi.getInput().get(0);
        UnaryOp uop = (UnaryOp) hi.getInput().get(1);

        if (uop.getOp() == OpOp1.SIGMOID) {
            Hop arg = uop.getInput().get(0);
            if (arg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
                // Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
                hnew = buildWeightedSigmoidOp(hi, W, arg.getInput().get(0), arg.getInput().get(1), false, false);
                LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
            }
            else if (HopRewriteUtils.isBinary(arg, OpOp2.MINUS)
                && arg.getInput().get(0) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) arg.getInput().get(0)) == 0
                && arg.getInput().get(1) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(1).getInput().get(0), true)) {
                // Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
                Hop mm = arg.getInput().get(1);
                hnew = buildWeightedSigmoidOp(hi, W, mm.getInput().get(0), mm.getInput().get(1), false, true);
                LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
            }
        }
        else if (uop.getOp() == OpOp1.LOG
            && HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)) {
            // argument of the inner sigmoid
            Hop arg = uop.getInput().get(0).getInput().get(0);
            if (arg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
                // Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
                hnew = buildWeightedSigmoidOp(hi, W, arg.getInput().get(0), arg.getInput().get(1), true, false);
                LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
            }
            else if (HopRewriteUtils.isBinary(arg, OpOp2.MINUS)
                && arg.getInput().get(0) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) arg.getInput().get(0)) == 0
                && arg.getInput().get(1) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(1).getInput().get(0), true)) {
                // Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
                Hop mm = arg.getInput().get(1);
                hnew = buildWeightedSigmoidOp(hi, W, mm.getInput().get(0), mm.getInput().get(1), true, true);
                LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
            }
        }
    }

    // no pattern matched: keep the original operator
    if (hnew == null)
        return hi;

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}

/**
 * Creates the WSIGMOID quaternary operator for one matched pattern. The right
 * matrix-multiply input is un-transposed if it already is a transpose
 * operation, otherwise wrapped in a new transpose.
 *
 * @param root original multiplication, provides the name of the new operator
 * @param W weight matrix (left input of root); also provides the block sizes
 * @param Y left input of the matrix multiplication
 * @param tX right input of the matrix multiplication
 * @param flagLog second-to-last constructor flag, set for the log patterns
 * @param flagMinus last constructor flag, set for the minus patterns
 * @return the new operator with block sizes and size information set
 */
private static Hop buildWeightedSigmoidOp(Hop root, Hop W, Hop Y, Hop tX, boolean flagLog, boolean flagMinus) {
    Hop X = HopRewriteUtils.isTransposeOperation(tX)
        ? tX.getInput().get(0)
        : HopRewriteUtils.createTranspose(tX);
    Hop wsig = new QuaternaryOp(root.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, X, flagLog, flagMinus);
    wsig.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    wsig.refreshSizeInformation();
    return wsig;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 29 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyEmptySortOperation.

/**
 * Replaces a sort over an input that HopRewriteUtils.isEmpty reports as empty.
 * The rewrite only applies if the index-return argument (input 3 of the sort)
 * is a literal: with index return the result becomes a sequence, e.g.,
 * order(X, indexreturn=TRUE) -> seq(1,nrow(X),1); without it the result is a
 * data generator created with value 0.
 *
 * @param parent parent high-level operator whose child reference is rewired
 * @param hi high-level operator inspected for the pattern
 * @param pos position of hi within the inputs of parent
 * @return the replacement operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) {
    boolean isSort = hi instanceof ReorgOp && ((ReorgOp) hi).getOp() == ReOrgOp.SORT;
    if (!isSort)
        return hi;

    ReorgOp sortOp = (ReorgOp) hi;
    Hop data = sortOp.getInput().get(0);
    if (!HopRewriteUtils.isEmpty(data))
        return hi;

    // the rewrite target depends on the index-return flag, which must be known
    Hop ixretArg = sortOp.getInput().get(3);
    boolean ixret = false;
    Hop replacement = null;
    if (ixretArg instanceof LiteralOp) {
        ixret = HopRewriteUtils.getBooleanValue((LiteralOp) ixretArg);
        replacement = ixret
            ? HopRewriteUtils.createSeqDataGenOp(data)
            : HopRewriteUtils.createDataGenOp(data, 0);
    }

    // index return unknown at this point: nothing to do
    if (replacement == null)
        return hi;

    // modify dag
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    LOG.debug("Applied simplifyEmptySortOperation (indexreturn=" + ixret + ").");
    return replacement;
}
Also used : ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 30 with LiteralOp

use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyDotProductSum.

/**
 * Rewrites a full sum over a column vector of the form sum(v^2) or sum(v1*v2)
 * into cast.as.scalar(t(v1) %*% v2), which avoids materializing the
 * element-wise intermediate.
 *
 * NOTE: dot-product-sum could also be applied to a general sum(a*b). However,
 * we restrict ourselves to sum(a^2) and transitively sum(a*a) since a general
 * mm a%*%b on MR can also be counter-productive (e.g., MMCJ) while tsmm is
 * always beneficial.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    // sum, as full aggregate, over a single-column input (for correctness)
    boolean fullVectorSum = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && hi.getInput().get(0).getDim2() == 1;
    if (!fullVectorSum)
        return hi;

    Hop hi2 = hi.getInput().get(0);
    Hop baLeft = null;
    Hop baRight = null;

    if (HopRewriteUtils.isBinary(hi2, OpOp2.POW)
        && hi2.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValue((LiteralOp) hi2.getInput().get(1)) == 2
        && hi2.getParent().size() == 1) {
        // sum(v^2) without other consumers of v^2; might have been rewritten from sum(v*v)
        baLeft = hi2.getInput().get(0);
        baRight = baLeft;
    }
    else if (HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1)) {
        // sum(v1*v2); the third isBinary argument presumably limits the number
        // of consumers to the sum itself - TODO confirm
        Hop lhs = hi2.getInput().get(0);
        Hop rhs = hi2.getInput().get(1);
        if (lhs.getDim2() == 1 && rhs.getDim2() == 1
            // prevent rewriting sum(v1*v2*v3), which is later compiled into a ta+* lop
            && !HopRewriteUtils.isBinary(lhs, OpOp2.MULT)
            && !HopRewriteUtils.isBinary(rhs, OpOp2.MULT)
            // do not rewrite (A^2)*B, let tak+* handle it
            && (!ALLOW_SUM_PRODUCT_REWRITES || !isPowerWithLiteralTwo(lhs))
            // do not rewrite B*(A^2), let tak+* handle it
            && (!ALLOW_SUM_PRODUCT_REWRITES || !isPowerWithLiteralTwo(rhs))) {
            baLeft = lhs;
            baRight = rhs;
        }
    }

    // perform actual rewrite (if necessary)
    if (baLeft != null && baRight != null) {
        // create new operator chain
        ReorgOp trans = HopRewriteUtils.createTranspose(baLeft);
        AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight);
        UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);
        // rehang new subdag under parent node
        HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
        HopRewriteUtils.cleanupUnreferenced(hi, hi2);
        hi = cast;
        LOG.debug("Applied simplifyDotProductSum.");
    }
    return hi;
}

/**
 * Checks whether the given operator is a binary power operation whose
 * exponent is a literal with long value 2.
 *
 * @param h high-level operator to inspect
 * @return true if h has the form A^2 with a literal exponent
 */
private static boolean isPowerWithLiteralTwo(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.POW)
        && h.getInput().get(1) instanceof LiteralOp
        && ((LiteralOp) h.getInput().get(1)).getLongValue() == 2;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp)

Aggregations

LiteralOp (org.apache.sysml.hops.LiteralOp): 102, Hop (org.apache.sysml.hops.Hop): 88, AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 30, BinaryOp (org.apache.sysml.hops.BinaryOp): 27, AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 25, UnaryOp (org.apache.sysml.hops.UnaryOp): 24, DataOp (org.apache.sysml.hops.DataOp): 23, IndexingOp (org.apache.sysml.hops.IndexingOp): 17, ArrayList (java.util.ArrayList): 13, DataGenOp (org.apache.sysml.hops.DataGenOp): 13, LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 12, HashMap (java.util.HashMap): 11, DataIdentifier (org.apache.sysml.parser.DataIdentifier): 10, ReorgOp (org.apache.sysml.hops.ReorgOp): 9, MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 9, HopsException (org.apache.sysml.hops.HopsException): 8, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 8, StatementBlock (org.apache.sysml.parser.StatementBlock): 8, TernaryOp (org.apache.sysml.hops.TernaryOp): 7, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 6