Search in sources:

Example 11 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyColSumsMVMult.

/**
 * Rewrites colSums(X*Y) into t(Y) %*% X when X has more than one row and
 * column and Y is a column vector with more than one row. The additional
 * transpose is expected to be removed by another rewrite if unnecessary,
 * i.e., if Y==t(Z).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the new matrix multiply if the rewrite applied, otherwise hi
 */
private static Hop simplifyColSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp colAgg = (AggUnaryOp) hi;
    Hop product = colAgg.getInput().get(0);

    // pattern: column-wise sum over an element-wise multiplication
    boolean colSumOverMult = colAgg.getOp() == AggOp.SUM
            && colAgg.getDirection() == Direction.Col
            && HopRewriteUtils.isBinary(product, OpOp2.MULT);
    if (!colSumOverMult) {
        return hi;
    }

    Hop matrix = product.getInput().get(0);
    Hop vector = product.getInput().get(1);

    // left operand must be a true matrix, right operand a column vector
    boolean matrixTimesColVector = matrix.getDim1() > 1
            && matrix.getDim2() > 1
            && vector.getDim1() > 1
            && vector.getDim2() == 1;
    if (!matrixTimesColVector) {
        return hi;
    }

    // build t(Y) %*% X
    ReorgOp transposedVector = HopRewriteUtils.createTranspose(vector);
    AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(transposedVector, matrix);

    // hang the new operator under the parent and drop the old sub-DAG
    HopRewriteUtils.replaceChildReference(parent, hi, matMult, pos);
    HopRewriteUtils.cleanupUnreferenced(colAgg, product);
    LOG.debug("Applied simplifyColSumsMVMult");
    return matMult;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 12 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyDotProductSum.

/**
 * Replaces a full sum over a column vector of the form sum(v^2) or
 * sum(v1*v2) by cast.as.scalar(t(v1) %*% v2), which avoids materializing
 * the element-wise intermediate.
 *
 * NOTE: dot-product-sum could also be applied to sum(a*b). However, we
 * restrict ourselves to sum(a^2) and transitively sum(a*a) since a general mm
 * a%*%b on MR can also be counter-productive (e.g., MMCJ) while tsmm is always
 * beneficial.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;

    // sum, full aggregate, over a column vector (the latter for correctness)
    boolean fullSumOverColVector = agg.getOp() == AggOp.SUM
            && agg.getDirection() == Direction.RowCol
            && hi.getInput().get(0).getDim2() == 1;
    if (!fullSumOverColVector) {
        return hi;
    }

    Hop aggInput = hi.getInput().get(0);
    Hop mmLeft = null;
    Hop mmRight = null;

    if (HopRewriteUtils.isBinary(aggInput, OpOp2.POW)
            && aggInput.getInput().get(1) instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) aggInput.getInput().get(1)) == 2
            && aggInput.getParent().size() == 1) {
        // sum(v^2) (might have been rewritten from sum(v*v)), with the sum
        // being the only consumer of the power
        mmLeft = aggInput.getInput().get(0);
        mmRight = mmLeft;
    } else if (HopRewriteUtils.isBinary(aggInput, OpOp2.MULT, 1)) {
        // sum(v1*v2) with the sum being the only consumer of the multiply
        Hop lhs = aggInput.getInput().get(0);
        Hop rhs = aggInput.getInput().get(1);

        // both operands are column vectors and neither is itself a multiply;
        // this prevents rewriting sum(v1*v2*v3), which is later compiled
        // into a ta+* lop
        boolean plainVectorProduct = lhs.getDim2() == 1
                && rhs.getDim2() == 1
                && !HopRewriteUtils.isBinary(lhs, OpOp2.MULT)
                && !HopRewriteUtils.isBinary(rhs, OpOp2.MULT);

        if (plainVectorProduct
                // do not rewrite (A^2)*B, let tak+* handle it
                && !(ALLOW_SUM_PRODUCT_REWRITES
                        && HopRewriteUtils.isBinary(lhs, OpOp2.POW)
                        && lhs.getInput().get(1) instanceof LiteralOp
                        && ((LiteralOp) lhs.getInput().get(1)).getLongValue() == 2)
                // do not rewrite B*(A^2), let tak+* handle it
                && !(ALLOW_SUM_PRODUCT_REWRITES
                        && HopRewriteUtils.isBinary(rhs, OpOp2.POW)
                        && rhs.getInput().get(1) instanceof LiteralOp
                        && ((LiteralOp) rhs.getInput().get(1)).getLongValue() == 2)) {
            mmLeft = lhs;
            mmRight = rhs;
        }
    }

    // perform the actual rewrite only if one of the patterns matched
    if (mmLeft != null && mmRight != null) {
        // new operator chain: cast.as.scalar(t(left) %*% right)
        ReorgOp transposedLeft = HopRewriteUtils.createTranspose(mmLeft);
        AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(transposedLeft, mmRight);
        UnaryOp scalarCast = HopRewriteUtils.createUnary(matMult, OpOp1.CAST_AS_SCALAR);

        // rehang the new sub-DAG under the parent node
        HopRewriteUtils.replaceChildReference(parent, hi, scalarCast, pos);
        HopRewriteUtils.cleanupUnreferenced(hi, aggInput);
        hi = scalarCast;
        LOG.debug("Applied simplifyDotProductSum.");
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 13 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyDiagMatrixMult.

/**
 * Rewrites diag(X%*%Y), where the diag produces a single-column result,
 * into a row-wise sum over X*t(Y).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the new row-sum operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
    // pattern: diag operation with a single output column (diagM2V)
    boolean diagToVector = hi instanceof ReorgOp
            && ((ReorgOp) hi).getOp() == ReOrgOp.DIAG
            && hi.getDim2() == 1;
    if (!diagToVector) {
        return hi;
    }

    // the diag input has to be a matrix multiplication X%*%Y
    Hop diagInput = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(diagInput)) {
        return hi;
    }

    Hop lhs = diagInput.getInput().get(0);
    Hop rhs = diagInput.getInput().get(1);

    // create new operators (incl refresh size inside for transpose)
    ReorgOp transposedRhs = HopRewriteUtils.createTranspose(rhs);
    BinaryOp cellwiseProduct = HopRewriteUtils.createBinary(lhs, transposedRhs, OpOp2.MULT);
    AggUnaryOp rowSums = HopRewriteUtils.createAggUnaryOp(cellwiseProduct, AggOp.SUM, Direction.Row);

    // rehang the new sub-DAG under the parent node
    HopRewriteUtils.replaceChildReference(parent, hi, rowSums, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, diagInput);
    LOG.debug("Applied simplifyDiagMatrixMult");
    return rowSums;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 14 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationStatic method pushdownSumBinaryMult.

/**
 * Pushes a full sum below a scalar-matrix multiplication:
 * sum(lambda*X) -> lambda*sum(X). Applies only if the multiplication has at
 * most one parent (the sum) and exactly one operand is a scalar while the
 * other is a matrix.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the new multiplication if the rewrite applied, otherwise hi
 */
private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp sum = (AggUnaryOp) hi;
    if (sum.getDirection() != Direction.RowCol || sum.getOp() != Hop.AggOp.SUM) {
        return hi;
    }

    // the multiplication may only be consumed by the sum
    Hop product = hi.getInput().get(0);
    if (!HopRewriteUtils.isBinary(product, OpOp2.MULT, 1)) {
        return hi;
    }

    Hop first = product.getInput().get(0);
    Hop second = product.getInput().get(1);

    // determine which operand is the scalar and which is the matrix
    boolean scalarThenMatrix = first.getDataType() == DataType.SCALAR
            && second.getDataType() == DataType.MATRIX;
    boolean matrixThenScalar = first.getDataType() == DataType.MATRIX
            && second.getDataType() == DataType.SCALAR;
    if (!scalarThenMatrix && !matrixThenScalar) {
        return hi;
    }
    Hop scalar = scalarThenMatrix ? first : second;
    Hop matrix = scalarThenMatrix ? second : first;

    // build lambda*sum(X) and hang it under the parent
    AggUnaryOp innerSum = HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol);
    Hop bop = HopRewriteUtils.createBinary(scalar, innerSum, OpOp2.MULT);
    HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
    LOG.debug("Applied pushdownSumBinaryMult.");
    return bop;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 15 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationStatic method pushdownUnaryAggTransposeOperation.

/**
 * Pushes a row or column aggregate below a transpose by flipping the
 * aggregation direction and transposing the result instead, e.g.,
 * rowSums(t(X)) -> t(colSums(X)) and colSums(t(X)) -> t(rowSums(X)).
 * Applies only if the aggregate has exactly one parent, its input is a
 * transpose with at most one parent, and the aggregation operation is
 * listed in LOOKUP_VALID_ROW_COL_AGGREGATE.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent
 * @return the new outer transpose if the rewrite applied, otherwise hi
 */
private static Hop pushdownUnaryAggTransposeOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp) || hi.getParent().size() != 1) {
        return hi;
    }
    AggUnaryOp uagg = (AggUnaryOp) hi;

    boolean rowOrColAggregate = uagg.getDirection() == Direction.Row
            || uagg.getDirection() == Direction.Col;
    if (!rowOrColAggregate
            || !HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
            || !HopRewriteUtils.isValidOp(uagg.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)) {
        return hi;
    }

    // remember the transpose input, then detach the inner transpose, the
    // aggregate's inputs, and the aggregate itself from its parent
    Hop innerTranspose = uagg.getInput().get(0);
    Hop input = innerTranspose.getInput().get(0);
    HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0));
    HopRewriteUtils.removeAllChildReferences(hi);
    HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);

    if (uagg.getDirection() == Direction.Row) {
        // pattern 1: row-aggregate to col aggregate, e.g., rowSums(t(X))->t(colSums(X))
        uagg.setDirection(Direction.Col);
        LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line " + hi.getBeginLine() + ").");
    } else if (uagg.getDirection() == Direction.Col) {
        // pattern 2: col-aggregate to row aggregate, e.g., colSums(t(X))->t(rowSums(X))
        uagg.setDirection(Direction.Row);
        LOG.debug("Applied pushdownUnaryAggTransposeOperation2 (line " + hi.getBeginLine() + ").");
    }

    // feed the original input directly into the flipped aggregate
    HopRewriteUtils.addChildReference(uagg, input);
    uagg.refreshSizeInformation();

    // create the outer transpose (incl refresh size) and put it in the
    // position the aggregate used to occupy; by def, same size
    Hop outerTranspose = HopRewriteUtils.createTranspose(uagg);
    HopRewriteUtils.addChildReference(parent, outerTranspose, pos);
    return outerTranspose;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp) 36, Hop (org.apache.sysml.hops.Hop) 33, AggBinaryOp (org.apache.sysml.hops.AggBinaryOp) 19, LiteralOp (org.apache.sysml.hops.LiteralOp) 14, BinaryOp (org.apache.sysml.hops.BinaryOp) 12, ReorgOp (org.apache.sysml.hops.ReorgOp) 11, UnaryOp (org.apache.sysml.hops.UnaryOp) 11, IndexingOp (org.apache.sysml.hops.IndexingOp) 7, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) 6, ArrayList (java.util.ArrayList) 5, TernaryOp (org.apache.sysml.hops.TernaryOp) 5, DataOp (org.apache.sysml.hops.DataOp) 4, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp) 4, CNode (org.apache.sysml.hops.codegen.cplan.CNode) 3, CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary) 3, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData) 3, CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary) 3, OpOp2 (org.apache.sysml.hops.Hop.OpOp2) 2, QuaternaryOp (org.apache.sysml.hops.QuaternaryOp) 2, CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary) 2