Search in sources :

Example 41 with IndexingOp

Use of org.apache.sysml.hops.IndexingOp in project SystemML by Apache.

Class RewriteIndexingVectorization, method vectorizeLeftIndexing.

/**
 * Merges a chain of single-cell left indexing operations that all write into the same row
 * (or, if that does not apply, the same column) into a single row-wise (column-wise) update.
 *
 * <p>Applies only when {@code hop} is a {@link LeftIndexingOp} whose row range and column range
 * are both single-valued. The method then walks down the target input (input 0) and collects
 * further {@link LeftIndexingOp}s while each one:
 * <ul>
 *   <li>has exactly one consumer,</li>
 *   <li>indexes a single row (column),</li>
 *   <li>uses the identical row (column) lower-bound hop as {@code hop}, and</li>
 *   <li>has a target whose column (row) dimension is greater than 1.</li>
 * </ul>
 *
 * <p>If at least two operations are collected, the rewrite does three things:
 * <ol>
 *   <li>The row (column) is extracted once from the bottom target via a new {@link IndexingOp}.</li>
 *   <li>The collected left indexing ops are rewritten to index row (column) 1 of that extract.</li>
 *   <li>A new {@link LeftIndexingOp} writes the updated row (column) back into the original
 *       target. All former consumers of {@code hop} are rewired to it.</li>
 * </ol>
 *
 * @param hop candidate root hop
 * @return the newly created top-level left indexing op if a rewrite was applied,
 *         otherwise {@code hop} unchanged
 */
@SuppressWarnings("unchecked")
private static Hop vectorizeLeftIndexing(Hop hop) {
    Hop ret = hop;
    if (// left indexing
    hop instanceof LeftIndexingOp) {
        LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
        boolean isSingleRow = ihop0.isRowLowerEqualsUpper();
        boolean isSingleCol = ihop0.isColLowerEqualsUpper();
        // the column variant below is only attempted if the row variant did not fire
        boolean appliedRow = false;
        if (isSingleRow && isSingleCol) {
            // collect simple chains (w/o multiple consumers) of left indexing ops
            ArrayList<Hop> ihops = new ArrayList<>();
            ihops.add(ihop0);
            Hop current = ihop0;
            // walk down the target input (input 0) of the chain
            while (current.getInput().get(0) instanceof LeftIndexingOp) {
                LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
                if (// multiple consumers, i.e., not a simple chain
                tmp.getParent().size() > 1 || // row merge not applicable
                !((LeftIndexingOp) tmp).isRowLowerEqualsUpper() || // not the same row (identity comparison of the row-lower hop)
                tmp.getInput().get(2) != ihop0.getInput().get(2) || // target is single column or unknown
                tmp.getInput().get(0).getDim2() <= 1) {
                    break;
                }
                ihops.add(tmp);
                current = tmp;
            }
            // apply rewrite if found candidates
            if (ihops.size() > 1) {
                // original target below the bottom-most collected left indexing op
                Hop input = current.getInput().get(0);
                // keep before reset
                Hop rowExpr = ihop0.getInput().get(2);
                // new row indexing operator: rows rowExpr..rowExpr, cols 1..createValueHop(input, false)
                // (presumably ncol(input) -- TODO confirm createValueHop semantics)
                IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false);
                HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newRix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newRix.getInput()) c.resetVisitStatus();
                // rewrite bottom left indexing operator
                // input data: the bottom op now writes into the extracted row instead of the full target
                HopRewriteUtils.removeChildReference(current, input);
                HopRewriteUtils.addChildReference(current, newRix, 0);
                // reset row index all candidates and refresh sizes (bottom-up)
                for (int i = ihops.size() - 1; i >= 0; i--) {
                    Hop c = ihops.get(i);
                    // row lower expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(2), new LiteralOp(1), 2);
                    // row upper expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3);
                    ((LeftIndexingOp) c).setRowLowerEqualsUpper(true);
                    c.refreshSizeInformation();
                }
                // new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
                // (note: it's important to clone the parent list before creating newLix on top of ihop0)
                ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
                ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
                for (Hop parent : ihop0parents) {
                    int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
                    // input data
                    HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
                    ihop0parentsPos.add(posp);
                }
                // write the updated row (ihop0) back into the original target at row rowExpr
                LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false);
                HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newLix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newLix.getInput()) c.resetVisitStatus();
                // re-attach every former consumer of ihop0 to newLix at its original input position
                for (int i = 0; i < ihop0parentsPos.size(); i++) {
                    Hop parent = ihop0parents.get(i);
                    int posp = ihop0parentsPos.get(i);
                    HopRewriteUtils.addChildReference(parent, newLix, posp);
                }
                appliedRow = true;
                ret = newLix;
                LOG.debug("Applied vectorizeLeftIndexingRow for hop " + hop.getHopID());
            }
        }
        if (isSingleRow && isSingleCol && !appliedRow) {
            // collect simple chains (w/o multiple consumers) of left indexing ops
            ArrayList<Hop> ihops = new ArrayList<>();
            ihops.add(ihop0);
            Hop current = ihop0;
            // walk down the target input (input 0) of the chain
            while (current.getInput().get(0) instanceof LeftIndexingOp) {
                LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
                if (// multiple consumers, i.e., not a simple chain
                tmp.getParent().size() > 1 || // col merge not applicable
                !((LeftIndexingOp) tmp).isColLowerEqualsUpper() || // not the same col (identity comparison of the col-lower hop)
                tmp.getInput().get(4) != ihop0.getInput().get(4) || // target is single row or unknown
                tmp.getInput().get(0).getDim1() <= 1) {
                    break;
                }
                ihops.add(tmp);
                current = tmp;
            }
            // apply rewrite if found candidates
            if (ihops.size() > 1) {
                // original target below the bottom-most collected left indexing op
                Hop input = current.getInput().get(0);
                // keep before reset
                Hop colExpr = ihop0.getInput().get(4);
                // new column indexing operator: rows 1..createValueHop(input, true), cols colExpr..colExpr
                // (presumably nrow(input) -- TODO confirm createValueHop semantics)
                IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
                HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newRix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newRix.getInput()) c.resetVisitStatus();
                // rewrite bottom left indexing operator
                // input data: the bottom op now writes into the extracted column instead of the full target
                HopRewriteUtils.removeChildReference(current, input);
                HopRewriteUtils.addChildReference(current, newRix, 0);
                // reset col index all candidates and refresh sizes (bottom-up)
                for (int i = ihops.size() - 1; i >= 0; i--) {
                    Hop c = ihops.get(i);
                    // col lower expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4);
                    // col upper expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(5), new LiteralOp(1), 5);
                    ((LeftIndexingOp) c).setColLowerEqualsUpper(true);
                    c.refreshSizeInformation();
                }
                // new column left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
                // (note: it's important to clone the parent list before creating newLix on top of ihop0)
                ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
                ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
                for (Hop parent : ihop0parents) {
                    int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
                    // input data
                    HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
                    ihop0parentsPos.add(posp);
                }
                // write the updated column (ihop0) back into the original target at column colExpr
                LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
                HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newLix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newLix.getInput()) c.resetVisitStatus();
                // re-attach every former consumer of ihop0 to newLix at its original input position
                for (int i = 0; i < ihop0parentsPos.size(); i++) {
                    Hop parent = ihop0parents.get(i);
                    int posp = ihop0parentsPos.get(i);
                    HopRewriteUtils.addChildReference(parent, newLix, posp);
                }
                ret = newLix;
                LOG.debug("Applied vectorizeLeftIndexingCol for hop " + hop.getHopID());
            }
        }
    }
    return ret;
}
Also used: IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 42 with IndexingOp

Use of org.apache.sysml.hops.IndexingOp in project SystemML by Apache.

Class RewriteIndexingVectorization, method vectorizeRightLeftIndexingChains.

/**
 * Collapses a chain of consecutive right-left indexing pairs into one pair.
 *
 * <p>For example, {@code B[,1] = A[,1]; B[,2] = A[,2];} becomes {@code B[,1:2] = A[,1:2]}.
 * The same applies to the row-wise pattern.
 *
 * <p>The root {@code hi} must meet these conditions:
 * <ul>
 *   <li>It is a {@link LeftIndexingOp} with at least one consumer.</li>
 *   <li>Its source (input 1) is an {@link IndexingOp} consumed only by {@code hi}.</li>
 *   <li>Both ops index a single row (or both a single column).</li>
 *   <li>Both are full-row (or full-column) indexing.</li>
 * </ul>
 *
 * <p>Further pairs are collected along the target input (input 0) while
 * {@code isConsecutiveLeftRightIndexing} accepts them. Every intermediate left indexing op and its
 * right indexing source must also have exactly one consumer. The semantics of
 * {@code isConsecutiveLeftRightIndexing} are presumably adjacent row/column indices; TODO confirm.
 *
 * <p>If at least two pairs are found, they are replaced by a single
 * {@code lix(target, rix(src, ...))} spanning {@code n} rows/columns. All consumers of
 * {@code hi} are rewired to the new op, and all child references of the replaced ops are removed.
 *
 * @param hi candidate root hop
 * @return the new left indexing op if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop vectorizeRightLeftIndexingChains(Hop hi) {
    // check for valid root operator; without any consumer there is nothing to rewire the result into
    if (!(hi instanceof LeftIndexingOp && hi.getInput().get(1) instanceof IndexingOp && hi.getInput().get(1).getParent().size() == 1) || hi.getParent().isEmpty())
        return hi;
    LeftIndexingOp lix0 = (LeftIndexingOp) hi;
    IndexingOp rix0 = (IndexingOp) hi.getInput().get(1);
    if (!(lix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper()) || lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper())
        return hi;
    boolean row = lix0.isRowLowerEqualsUpper();
    if (!((row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0)) && (row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0))))
        return hi;
    // determine consecutive left-right indexing chains for rows/columns
    List<LeftIndexingOp> lix = new ArrayList<>();
    lix.add(lix0);
    List<IndexingOp> rix = new ArrayList<>();
    rix.add(rix0);
    LeftIndexingOp clix = lix0;
    IndexingOp crix = rix0;
    while (isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0)) && clix.getInput().get(0).getParent().size() == 1 && clix.getInput().get(0).getInput().get(1).getParent().size() == 1) {
        clix = (LeftIndexingOp) clix.getInput().get(0);
        crix = (IndexingOp) clix.getInput().get(1);
        lix.add(clix);
        rix.add(crix);
    }
    // rewrite pattern if at least two consecutive pairs (lix and rix always have equal length)
    int n = lix.size();
    if (n >= 2) {
        // new right indexing: start at the deepest pair's lower bound and extend by n-1 rows/cols
        IndexingOp rixn = rix.get(n - 1);
        Hop rlrix = rixn.getInput().get(1);
        Hop rurix = row ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(n - 1), OpOp2.PLUS) : rixn.getInput().get(2);
        Hop clrix = rixn.getInput().get(3);
        Hop curix = row ? rixn.getInput().get(4) : HopRewriteUtils.createBinary(clrix, new LiteralOp(n - 1), OpOp2.PLUS);
        IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix);
        // new left indexing into the deepest pair's target, with the same widened range
        LeftIndexingOp lixn = lix.get(n - 1);
        Hop rllix = lixn.getInput().get(2);
        Hop rulix = row ? HopRewriteUtils.createBinary(rllix, new LiteralOp(n - 1), OpOp2.PLUS) : lixn.getInput().get(3);
        Hop cllix = lixn.getInput().get(4);
        Hop culix = row ? lixn.getInput().get(5) : HopRewriteUtils.createBinary(cllix, new LiteralOp(n - 1), OpOp2.PLUS);
        LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix);
        // rewire ALL consumers of the old root, not only the first one: the old chain is
        // dismantled below, so any parent still pointing at hi would reference a hop without
        // inputs. Iterate over a copy since replaceChildReference presumably updates
        // hi's parent list while we loop.
        for (Hop parent : new ArrayList<>(hi.getParent()))
            HopRewriteUtils.replaceChildReference(parent, hi, lixNew);
        for (int i = 0; i < n; i++) {
            HopRewriteUtils.removeAllChildReferences(lix.get(i));
            HopRewriteUtils.removeAllChildReferences(rix.get(i));
        }
        hi = lixNew;
        LOG.debug("Applied vectorizeRightLeftIndexingChains (line " + hi.getBeginLine() + ")");
    }
    return hi;
}
Also used: IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) ArrayList(java.util.ArrayList) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Aggregations

IndexingOp (org.apache.sysml.hops.IndexingOp)42 Hop (org.apache.sysml.hops.Hop)35 LiteralOp (org.apache.sysml.hops.LiteralOp)32 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)23 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)19 DataOp (org.apache.sysml.hops.DataOp)16 UnaryOp (org.apache.sysml.hops.UnaryOp)15 BinaryOp (org.apache.sysml.hops.BinaryOp)11 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)9 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)8 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)8 StatementBlock (org.apache.sysml.parser.StatementBlock)8 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)8 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)7 TernaryOp (org.apache.sysml.hops.TernaryOp)7 ArrayList (java.util.ArrayList)6 MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject)6 ReorgOp (org.apache.sysml.hops.ReorgOp)5 CNode (org.apache.sysml.hops.codegen.cplan.CNode)4 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)4