Search in sources:

Example 6 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifyColSumsMVMult.

/**
 * Rewrites {@code colSums(X * y)} into {@code t(y) %*% X} when {@code X} is a
 * matrix (both dimensions greater than one) and {@code y} is a column vector
 * (more than one row, exactly one column). If the introduced transpose is
 * unnecessary (e.g., y == t(z)), another rewrite is expected to remove it.
 *
 * @param parent the consumer of {@code hi}
 * @param hi     the candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new matrix-multiply hop if the rewrite fired, otherwise {@code hi}
 * @throws HopsException as declared; presumably propagated from the
 *         HopRewriteUtils helpers
 */
private Hop simplifyColSumsMVMult(Hop parent, Hop hi, int pos) throws HopsException {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop product = agg.getInput().get(0);

    // pattern root: colSums over an elementwise multiply b(*)
    boolean colSumOfMult = agg.getOp() == AggOp.SUM
        && agg.getDirection() == Direction.Col
        && HopRewriteUtils.isBinary(product, OpOp2.MULT);
    if (!colSumOfMult) {
        return hi;
    }

    Hop matrix = product.getInput().get(0);
    Hop vector = product.getInput().get(1);
    // only matrix-times-column-vector (MV) products qualify
    boolean matrixTimesColVector = matrix.getDim1() > 1 && matrix.getDim2() > 1
        && vector.getDim1() > 1 && vector.getDim2() == 1;
    if (!matrixTimesColVector) {
        return hi;
    }

    // build t(y) %*% X and hang it where colSums(X*y) used to be
    ReorgOp vectorT = HopRewriteUtils.createTranspose(vector);
    AggBinaryOp replacement = HopRewriteUtils.createMatrixMultiply(vectorT, matrix);
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    HopRewriteUtils.cleanupUnreferenced(agg, product);
    LOG.debug("Applied simplifyColSumsMVMult");
    return replacement;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 7 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationStatic, method pushdownUnaryAggTransposeOperation.

/**
 * Pushes a row- or column-wise unary aggregate below a transpose, e.g.
 * {@code rowSums(t(X)) -> t(colSums(X))} and
 * {@code colSums(t(X)) -> t(rowSums(X))}.
 *
 * Only fires when the aggregate has a single parent, its direction is Row or
 * Col, its input satisfies {@code isTransposeOperation(.., 1)} (presumably a
 * transpose with at most one consumer; verify), and the aggregate op is
 * listed in {@code LOOKUP_VALID_ROW_COL_AGGREGATE}. The aggregate hop is
 * reused with its direction flipped.
 *
 * @param parent the consumer of {@code hi}
 * @param hi     the candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new outer transpose if the rewrite fired, otherwise {@code hi}
 */
private Hop pushdownUnaryAggTransposeOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp uagg = (AggUnaryOp) hi;
    Direction origDir = uagg.getDirection();
    boolean applicable = hi.getParent().size() == 1
        && (origDir == Direction.Row || origDir == Direction.Col)
        && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
        && HopRewriteUtils.isValidOp(uagg.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE);
    if (!applicable) {
        return hi;
    }

    // grab X from agg(t(X)) before unlinking anything
    Hop innerTranspose = uagg.getInput().get(0);
    Hop input = innerTranspose.getInput().get(0);
    // detach t(X) from X, the aggregate from t(X), and the parent from the aggregate
    HopRewriteUtils.removeAllChildReferences(innerTranspose);
    HopRewriteUtils.removeAllChildReferences(hi);
    HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);

    if (origDir == Direction.Row) {
        // rowSums(t(X)) -> t(colSums(X))
        uagg.setDirection(Direction.Col);
        LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line " + hi.getBeginLine() + ").");
    } else if (origDir == Direction.Col) {
        // colSums(t(X)) -> t(rowSums(X))
        uagg.setDirection(Direction.Row);
        LOG.debug("Applied pushdownUnaryAggTransposeOperation2 (line " + hi.getBeginLine() + ").");
    }

    // re-point the aggregate at X, then wrap it in an outer transpose
    HopRewriteUtils.addChildReference(uagg, input);
    uagg.refreshSizeInformation();
    Hop outerTranspose = HopRewriteUtils.createTranspose(uagg); // includes size refresh
    HopRewriteUtils.addChildReference(parent, outerTranspose, pos); // same size by definition
    return outerTranspose;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 8 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifyColwiseAggregate.

/**
 * Simplifies a column-wise unary aggregate whose input is a vector.
 * <ul>
 *   <li>Row-vector input (1 row): {@code colVars(x)} becomes a zero row vector,
 *       and any other valid column aggregate is removed, because it returns
 *       the input itself.</li>
 *   <li>Column-vector input (1 column): the column aggregate is turned into a
 *       full (RowCol) aggregate producing a scalar. A cast-to-matrix is hung
 *       over it so consumers still receive a matrix.</li>
 * </ul>
 * Only aggregate ops listed in {@code LOOKUP_VALID_ROW_COL_AGGREGATE} qualify.
 *
 * @param parent the consumer of {@code hi}
 * @param hi     the candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the replacement hop if a rewrite fired, otherwise {@code hi}
 * @throws HopsException as declared; presumably propagated from the
 *         HopRewriteUtils helpers
 */
private Hop simplifyColwiseAggregate(Hop parent, Hop hi, int pos) throws HopsException {
    if (hi instanceof AggUnaryOp) {
        AggUnaryOp uhi = (AggUnaryOp) hi;
        Hop input = uhi.getInput().get(0);
        if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)
            && uhi.getDirection() == Direction.Col) {
            if (input.getDim1() == 1) {
                if (uhi.getOp() == AggOp.VAR) {
                    // The column variances of a row vector are all zero, so
                    // rewrite colVars(X) to a row vector of zeros.
                    // NOTE(review): createDataGenOp(uhi, input, 0) presumably
                    // builds a 0-filled matrix shaped like uhi's output; verify.
                    Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0);
                    HopRewriteUtils.replaceChildReference(parent, hi, emptyRow, pos);
                    HopRewriteUtils.cleanupUnreferenced(hi, input);
                    hi = emptyRow;
                    LOG.debug("Applied simplifyColwiseAggregate for colVars");
                } else {
                    // Every other valid column aggregate over a row vector returns
                    // the row vector itself, so drop the aggregate.
                    HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
                    HopRewriteUtils.cleanupUnreferenced(hi);
                    hi = input;
                    LOG.debug("Applied simplifyColwiseAggregate1");
                }
            } else if (input.getDim2() == 1) {
                // Snapshot the current parents before the cast is created over the
                // aggregate, since that cast becomes a parent too. The typed copy
                // constructor gives the same shallow copy as clone() without an
                // unchecked cast.
                ArrayList<Hop> parents = new ArrayList<Hop>(hi.getParent());
                // simplify the col-aggregate to a full aggregate (scalar output)
                uhi.setDirection(Direction.RowCol);
                uhi.setDataType(DataType.SCALAR);
                // cast back to matrix so the output data type stays the same
                UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
                // hang the cast under every original parent, at the aggregate's slot
                for (Hop p : parents) {
                    int ix = HopRewriteUtils.getChildReferencePos(p, hi);
                    HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
                }
                hi = cast;
                LOG.debug("Applied simplifyColwiseAggregate2");
            }
        }
    }
    return hi;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList)

Example 9 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifySumMatrixMult.

/**
 * Pushes a sum aggregate below a matrix multiplication:
 * <ul>
 *   <li>{@code sum(A%*%B) -> sum(t(colSums(A))*rowSums(B))}, later rewritten
 *       to a dot product;</li>
 *   <li>{@code colSums(A%*%B) -> colSums(A)%*%B};</li>
 *   <li>{@code rowSums(A%*%B) -> A%*%rowSums(B)}.</li>
 * </ul>
 * Not applied when the product is already 1x1 (a dot product), or when the
 * sum is not the product's only consumer, which would compute redundant work.
 * The aggregate hop {@code hi} is kept and re-pointed at the new sub-DAG, so
 * {@code parent} and {@code pos} are not needed here.
 *
 * @param parent the consumer of {@code hi} (unused)
 * @param hi     the candidate hop
 * @param pos    input position of {@code hi} within {@code parent} (unused)
 * @return {@code hi}, rewritten in place if the pattern matched
 */
private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) {
    // The identities above only hold for a real sum-of-products matrix
    // multiply. Require isMatrixMultiply rather than just any AggBinaryOp.
    if (hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0))
        && (hi.getInput().get(0).getDim1() > 1 || hi.getInput().get(0).getDim2() > 1)
        && hi.getInput().get(0).getParent().size() == 1) {
        AggUnaryOp agg = (AggUnaryOp) hi;
        Hop hi2 = hi.getInput().get(0);
        Hop left = hi2.getInput().get(0);
        Hop right = hi2.getInput().get(1);
        // unlink the matrix multiply from the aggregate
        HopRewriteUtils.removeChildReference(hi, hi2);

        Hop root = null;
        if (agg.getDirection() == Direction.RowCol) {
            // sum(A%*%B) -> sum(t(colSums(A))*rowSums(B))
            AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
            ReorgOp trans = HopRewriteUtils.createTranspose(colSum);
            AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
            root = HopRewriteUtils.createBinary(trans, rowSum, OpOp2.MULT);
            LOG.debug("Applied simplifySumMatrixMult RC.");
        } else if (agg.getDirection() == Direction.Col) {
            // colSums(A%*%B) -> colSums(A)%*%B
            AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
            root = HopRewriteUtils.createMatrixMultiply(colSum, right);
            LOG.debug("Applied simplifySumMatrixMult C.");
        } else if (agg.getDirection() == Direction.Row) {
            // rowSums(A%*%B) -> A%*%rowSums(B)
            AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
            root = HopRewriteUtils.createMatrixMultiply(left, rowSum);
            LOG.debug("Applied simplifySumMatrixMult R.");
        }

        // hang the new sub-DAG under the aggregate (hi itself stays in place)
        HopRewriteUtils.addChildReference(hi, root, 0);
        hi.refreshSizeInformation();
        // drop the old matrix multiply if nothing references it anymore
        HopRewriteUtils.cleanupUnreferenced(hi2);
    }
    return hi;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 10 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifyDiagMatrixMult.

/**
 * Rewrites the diagonal of a matrix product into a row-wise sum of an
 * elementwise product: {@code diag(X %*% Y) -> rowSums(X * t(Y))}. This avoids
 * materializing the full product. Only the matrix-to-vector form of diag
 * (single output column) is considered.
 *
 * @param parent the consumer of {@code hi}
 * @param hi     the candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new rowSums hop if the rewrite fired, otherwise {@code hi}
 */
private Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
    boolean diagToVector = hi instanceof ReorgOp
        && ((ReorgOp) hi).getOp() == ReOrgOp.DIAG
        && hi.getDim2() == 1;
    if (!diagToVector) {
        return hi;
    }
    Hop product = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(product)) {
        return hi;
    }

    Hop x = product.getInput().get(0);
    Hop y = product.getInput().get(1);
    // build rowSums(X * t(Y)); the transpose helper refreshes sizes itself
    ReorgOp yT = HopRewriteUtils.createTranspose(y);
    BinaryOp elementwise = HopRewriteUtils.createBinary(x, yT, OpOp2.MULT);
    AggUnaryOp diagonal = HopRewriteUtils.createAggUnaryOp(elementwise, AggOp.SUM, Direction.Row);

    // swap the new sub-DAG in for diag(X %*% Y) and drop orphaned hops
    HopRewriteUtils.replaceChildReference(parent, hi, diagonal, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);
    LOG.debug("Applied simplifyDiagMatrixMult");
    return diagonal;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 31, Hop (org.apache.sysml.hops.Hop): 29, AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 15, LiteralOp (org.apache.sysml.hops.LiteralOp): 13, BinaryOp (org.apache.sysml.hops.BinaryOp): 11, ReorgOp (org.apache.sysml.hops.ReorgOp): 10, UnaryOp (org.apache.sysml.hops.UnaryOp): 10, IndexingOp (org.apache.sysml.hops.IndexingOp): 5, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 5, ArrayList (java.util.ArrayList): 4, DataOp (org.apache.sysml.hops.DataOp): 4, TernaryOp (org.apache.sysml.hops.TernaryOp): 4, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 3, CNode (org.apache.sysml.hops.codegen.cplan.CNode): 3, CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 3, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 3, CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 3, OpOp2 (org.apache.sysml.hops.Hop.OpOp2): 2, QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 2, CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 2