Search in sources:

Example 11 with LeftIndexingOp

Use of org.apache.sysml.hops.LeftIndexingOp in the Apache project incubator-systemml.

From class RewriteForLoopVectorization, method vectorizeIndexedCopy.

/**
 * Tries to replace a for loop whose body is a single indexed copy by one range copy.
 * <p>
 * The rewrite applies only if all of the following hold:
 * <ul>
 * <li>the loop increment is the literal 1;</li>
 * <li>the loop body {@code csb} has exactly one root hop, of matrix type, whose first input is a
 * left indexing op {@code target[..] = source[..]};</li>
 * <li>{@code target} is a data op, and {@code source} is a right indexing op over a data op;</li>
 * <li>{@code checkLeftAndRightIndexing} accepts the pair for {@code itervar}.</li>
 * </ul>
 * When it applies, the bounds of the dimension reported by that check (rows or columns) are
 * replaced on both sides by {@code from} and {@code to}.
 *
 * @param sb        the original loop statement block, returned when the rewrite does not apply
 * @param csb       the loop body statement block, rewritten in place when the rewrite applies
 * @param from      loop lower bound, becomes the lower index bound
 * @param to        loop upper bound, becomes the upper index bound
 * @param increment loop increment, must be the literal 1
 * @param itervar   name of the loop iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeIndexedCopy(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // only a literal increment of exactly 1 maps onto a contiguous index range
    boolean unitIncrement = increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitIncrement)
        return sb;
    // the loop body must consist of exactly one root hop
    if (csb.getHops() == null || csb.getHops().size() != 1)
        return sb;
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp))
        return sb;
    // pattern: target[..] = source[..] where target and the indexed source are data ops
    LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
    Hop target = lix.getInput().get(0);
    Hop source = lix.getInput().get(1);
    if (!(target instanceof DataOp) || !(source instanceof IndexingOp) || !(source.getInput().get(0) instanceof DataOp))
        return sb;
    IndexingOp rix = (IndexingOp) source;
    // verdict[0]: rewrite applicable; verdict[1]: true for row indexing, false for column indexing
    boolean[] verdict = checkLeftAndRightIndexing(lix, rix, itervar);
    if (!verdict[0])
        return sb;
    boolean rowIx = verdict[1];
    // left indexing inputs 2/3 hold the row bounds and 4/5 the column bounds;
    // right indexing inputs are shifted down by one (1/2 rows, 3/4 columns)
    int lixLower = rowIx ? 2 : 4;
    int lixUpper = lixLower + 1;
    int rixLower = lixLower - 1;
    int rixUpper = lixUpper - 1;
    // modify left indexing bounds
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lixLower), from, lixLower);
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lixUpper), to, lixUpper);
    // modify right indexing bounds
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(rixLower), from, rixLower);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(rixUpper), to, rixUpper);
    updateLeftAndRightIndexingSizes(rowIx, lix, rix);
    LOG.debug("Applied vectorizeIndexedCopy.");
    return csb;
}
Also used: IndexingOp (org.apache.sysml.hops.IndexingOp), LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp), Hop (org.apache.sysml.hops.Hop), LiteralOp (org.apache.sysml.hops.LiteralOp), DataOp (org.apache.sysml.hops.DataOp), IfStatementBlock (org.apache.sysml.parser.IfStatementBlock), WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock), ForStatementBlock (org.apache.sysml.parser.ForStatementBlock), StatementBlock (org.apache.sysml.parser.StatementBlock)

Example 12 with LeftIndexingOp

Use of org.apache.sysml.hops.LeftIndexingOp in the Apache project incubator-systemml.

From class RewriteIndexingVectorization, method vectorizeRightLeftIndexingChains.

/**
 * Merges a chain of at least two consecutive left/right indexing pairs into one range copy.
 * <p>
 * Each pair in the chain has the form {@code T[r,:] = S[r,:]} (full-row copy) or the analogous
 * full-column copy, and consecutive pairs are as accepted by {@code isConsecutiveLeftRightIndexing}.
 * The chain is replaced by a single left indexing op that copies the rows (columns)
 * {@code [lo, lo + k]}, where {@code lo} is the index used by the bottom-most pair and
 * {@code k + 1} is the chain length.
 *
 * @param hi candidate root of the chain (the top-most left indexing op)
 * @return the new left indexing op if the rewrite was applied, otherwise {@code hi} unchanged
 */
private static Hop vectorizeRightLeftIndexingChains(Hop hi) {
    // check for valid root operator: left indexing whose rhs is a right indexing op with a single consumer
    if (!(hi instanceof LeftIndexingOp && hi.getInput().get(1) instanceof IndexingOp && hi.getInput().get(1).getParent().size() == 1))
        return hi;
    LeftIndexingOp lix0 = (LeftIndexingOp) hi;
    IndexingOp rix0 = (IndexingOp) hi.getInput().get(1);
    // both sides must index a single row or a single column, and agree on which
    if (!(lix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper()) || lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper())
        return hi;
    boolean row = lix0.isRowLowerEqualsUpper();
    // both sides must cover the full row (resp. the full column)
    if (!((row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0)) && (row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0))))
        return hi;
    // determine consecutive left-right indexing chains for rows/columns;
    // intermediate ops must have a single consumer so that they can be discarded afterwards
    List<LeftIndexingOp> lix = new ArrayList<>();
    lix.add(lix0);
    List<IndexingOp> rix = new ArrayList<>();
    rix.add(rix0);
    LeftIndexingOp clix = lix0;
    IndexingOp crix = rix0;
    while (isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0)) && clix.getInput().get(0).getParent().size() == 1 && clix.getInput().get(0).getInput().get(1).getParent().size() == 1) {
        clix = (LeftIndexingOp) clix.getInput().get(0);
        crix = (IndexingOp) clix.getInput().get(1);
        lix.add(clix);
        rix.add(crix);
    }
    // rewrite pattern if at least two consecutive pairs
    if (lix.size() >= 2) {
        // lix and rix always grow together, so they share the same last index
        int last = lix.size() - 1;
        // new right indexing: widen the bottom-most rhs from one row (column) to last+1 of them
        IndexingOp rixn = rix.get(last);
        Hop rlrix = rixn.getInput().get(1);
        Hop rurix = row ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(last), OpOp2.PLUS) : rixn.getInput().get(2);
        Hop clrix = rixn.getInput().get(3);
        Hop curix = row ? rixn.getInput().get(4) : HopRewriteUtils.createBinary(clrix, new LiteralOp(last), OpOp2.PLUS);
        IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix);
        // new left indexing: write that range into the bottom-most lhs target
        LeftIndexingOp lixn = lix.get(last);
        Hop rllix = lixn.getInput().get(2);
        Hop rulix = row ? HopRewriteUtils.createBinary(rllix, new LiteralOp(last), OpOp2.PLUS) : lixn.getInput().get(3);
        Hop cllix = lixn.getInput().get(4);
        Hop culix = row ? lixn.getInput().get(5) : HopRewriteUtils.createBinary(cllix, new LiteralOp(last), OpOp2.PLUS);
        LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix);
        // rewire ALL consumers of the chain root: unlike the intermediates, the root is not
        // required to have a single parent, and any parent still referencing hi would see an op
        // whose inputs are removed below (iterate over a copy, replaceChildReference mutates it)
        for (Hop parent : new ArrayList<>(hi.getParent())) {
            HopRewriteUtils.replaceChildReference(parent, hi, lixNew);
        }
        // detach the now unused chain
        for (int i = 0; i < lix.size(); i++) {
            HopRewriteUtils.removeAllChildReferences(lix.get(i));
            HopRewriteUtils.removeAllChildReferences(rix.get(i));
        }
        hi = lixNew;
        LOG.debug("Applied vectorizeRightLeftIndexingChains (line " + hi.getBeginLine() + ")");
    }
    return hi;
}
Also used: IndexingOp (org.apache.sysml.hops.IndexingOp), LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp), ArrayList (java.util.ArrayList), Hop (org.apache.sysml.hops.Hop), LiteralOp (org.apache.sysml.hops.LiteralOp)

Example 13 with LeftIndexingOp

Use of org.apache.sysml.hops.LeftIndexingOp in the Apache project incubator-systemml.

From class RewriteIndexingVectorization, method vectorizeLeftIndexing.

/**
 * Vectorizes chains of left indexing ops that start from a single-cell write.
 * <p>
 * Row variant: starting at {@code hop} (row and column bounds each equal), follow the chain of
 * left indexing ops below it. Each op must write into the same row expression, have a single
 * consumer, and target an input with more than one column. If at least two ops qualify:
 * <ol>
 * <li>A new right indexing op extracts the row {@code input[rowExpr, 1:..]} once. The column
 * upper bound comes from {@code HopRewriteUtils.createValueHop(input, false)}, presumably
 * ncol(input); TODO confirm.</li>
 * <li>Every op in the chain is rebased onto that 1-row slice with row bounds 1:1.</li>
 * <li>A new left indexing op writes the updated slice back into {@code input} at
 * {@code rowExpr}, and takes over all parents of {@code hop}.</li>
 * </ol>
 * The column variant is symmetric (same column expression, target with more than one row) and is
 * only attempted if the row variant was not applied.
 * <p>
 * Only the chain root ({@code hop}) may have multiple consumers; intermediates must have exactly
 * one parent. The unchecked warning suppression covers the cast of the cloned parent list.
 *
 * @param hop candidate root of the chain
 * @return the new write-back left indexing op if a rewrite was applied, otherwise {@code hop}
 */
@SuppressWarnings("unchecked")
private static Hop vectorizeLeftIndexing(Hop hop) {
    Hop ret = hop;
    if (// left indexing
    hop instanceof LeftIndexingOp) {
        LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
        boolean isSingleRow = ihop0.isRowLowerEqualsUpper();
        boolean isSingleCol = ihop0.isColLowerEqualsUpper();
        // set once the row variant fires, so that the column variant is skipped
        boolean appliedRow = false;
        // ---- row variant: merge single-cell writes into the same row ----
        if (isSingleRow && isSingleCol) {
            // collect simple chains (w/o multiple consumers) of left indexing ops
            ArrayList<Hop> ihops = new ArrayList<>();
            ihops.add(ihop0);
            Hop current = ihop0;
            while (current.getInput().get(0) instanceof LeftIndexingOp) {
                LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
                // stop at the first op that breaks the chain
                // (each trailing comment below describes the condition on the FOLLOWING line)
                if (// multiple consumers, i.e., not a simple chain
                tmp.getParent().size() > 1 || // row merge not applicable
                !((LeftIndexingOp) tmp).isRowLowerEqualsUpper() || // not the same row
                tmp.getInput().get(2) != ihop0.getInput().get(2) || // target is single column or unknown
                tmp.getInput().get(0).getDim2() <= 1) {
                    break;
                }
                ihops.add(tmp);
                current = tmp;
            }
            // apply rewrite if found candidates
            if (ihops.size() > 1) {
                // base input below the bottom-most op of the chain
                Hop input = current.getInput().get(0);
                // keep before reset
                Hop rowExpr = ihop0.getInput().get(2);
                // new row indexing operator
                IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false);
                HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newRix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newRix.getInput()) c.resetVisitStatus();
                // rewrite bottom left indexing operator
                // input data: replace the base input of the bottom-most op by the extracted row
                HopRewriteUtils.removeChildReference(current, input);
                HopRewriteUtils.addChildReference(current, newRix, 0);
                // reset row index all candidates and refresh sizes (bottom-up)
                for (int i = ihops.size() - 1; i >= 0; i--) {
                    Hop c = ihops.get(i);
                    // row lower expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(2), new LiteralOp(1), 2);
                    // row upper expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3);
                    ((LeftIndexingOp) c).setRowLowerEqualsUpper(true);
                    c.refreshSizeInformation();
                }
                // new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
                // (note: it's important to clone the parent list before creating newLix on top of ihop0)
                ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
                ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
                for (Hop parent : ihop0parents) {
                    int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
                    // input data: detach ihop0 from this parent, remembering the input position
                    HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
                    ihop0parentsPos.add(posp);
                }
                LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false);
                HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newLix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newLix.getInput()) c.resetVisitStatus();
                // attach newLix to every former parent of ihop0 at the remembered position
                for (int i = 0; i < ihop0parentsPos.size(); i++) {
                    Hop parent = ihop0parents.get(i);
                    int posp = ihop0parentsPos.get(i);
                    HopRewriteUtils.addChildReference(parent, newLix, posp);
                }
                appliedRow = true;
                ret = newLix;
                LOG.debug("Applied vectorizeLeftIndexingRow for hop " + hop.getHopID());
            }
        }
        // ---- column variant: merge single-cell writes into the same column ----
        if (isSingleRow && isSingleCol && !appliedRow) {
            // collect simple chains (w/o multiple consumers) of left indexing ops
            ArrayList<Hop> ihops = new ArrayList<>();
            ihops.add(ihop0);
            Hop current = ihop0;
            while (current.getInput().get(0) instanceof LeftIndexingOp) {
                LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
                // stop at the first op that breaks the chain
                // (each trailing comment below describes the condition on the FOLLOWING line)
                if (// multiple consumers, i.e., not a simple chain
                tmp.getParent().size() > 1 || // col merge not applicable
                !((LeftIndexingOp) tmp).isColLowerEqualsUpper() || // not the same col
                tmp.getInput().get(4) != ihop0.getInput().get(4) || // target is single row or unknown
                tmp.getInput().get(0).getDim1() <= 1) {
                    break;
                }
                ihops.add(tmp);
                current = tmp;
            }
            // apply rewrite if found candidates
            if (ihops.size() > 1) {
                // base input below the bottom-most op of the chain
                Hop input = current.getInput().get(0);
                // keep before reset
                Hop colExpr = ihop0.getInput().get(4);
                // new column indexing operator
                IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
                HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newRix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newRix.getInput()) c.resetVisitStatus();
                // rewrite bottom left indexing operator
                // input data: replace the base input of the bottom-most op by the extracted column
                HopRewriteUtils.removeChildReference(current, input);
                HopRewriteUtils.addChildReference(current, newRix, 0);
                // reset col index all candidates and refresh sizes (bottom-up)
                for (int i = ihops.size() - 1; i >= 0; i--) {
                    Hop c = ihops.get(i);
                    // col lower expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4);
                    // col upper expr
                    HopRewriteUtils.replaceChildReference(c, c.getInput().get(5), new LiteralOp(1), 5);
                    ((LeftIndexingOp) c).setColLowerEqualsUpper(true);
                    c.refreshSizeInformation();
                }
                // new column left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
                // (note: it's important to clone the parent list before creating newLix on top of ihop0)
                ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
                ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
                for (Hop parent : ihop0parents) {
                    int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
                    // input data: detach ihop0 from this parent, remembering the input position
                    HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
                    ihop0parentsPos.add(posp);
                }
                LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
                HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
                newLix.refreshSizeInformation();
                // reset visit status of copied hops (otherwise hidden by left indexing)
                for (Hop c : newLix.getInput()) c.resetVisitStatus();
                // attach newLix to every former parent of ihop0 at the remembered position
                for (int i = 0; i < ihop0parentsPos.size(); i++) {
                    Hop parent = ihop0parents.get(i);
                    int posp = ihop0parentsPos.get(i);
                    HopRewriteUtils.addChildReference(parent, newLix, posp);
                }
                ret = newLix;
                LOG.debug("Applied vectorizeLeftIndexingCol for hop " + hop.getHopID());
            }
        }
    }
    return ret;
}
Also used: IndexingOp (org.apache.sysml.hops.IndexingOp), LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp), Hop (org.apache.sysml.hops.Hop), ArrayList (java.util.ArrayList), LiteralOp (org.apache.sysml.hops.LiteralOp)

Example 14 with LeftIndexingOp

Use of org.apache.sysml.hops.LeftIndexingOp in the Apache project incubator-systemml.

From class HopRewriteUtils, method createLeftIndexingOp.

/**
 * Creates a matrix/double left indexing op {@code lhs[rl:ru, cl:cu] = rhs}.
 * <p>
 * The op is named "tmp", takes its output block sizes and line numbers from {@code lhs}, and has
 * its size information refreshed before being returned.
 *
 * @param lhs target matrix
 * @param rhs values to write
 * @param rl  row lower bound
 * @param ru  row upper bound
 * @param cl  column lower bound
 * @param cu  column upper bound
 * @return the new left indexing op
 */
public static LeftIndexingOp createLeftIndexingOp(Hop lhs, Hop rhs, Hop rl, Hop ru, Hop cl, Hop cu) {
    // single row/column only when lower and upper bound are the very same hop (reference equality)
    final boolean singleRow = (rl == ru);
    final boolean singleCol = (cl == cu);
    final LeftIndexingOp lixop = new LeftIndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, lhs, rhs, rl, ru, cl, cu, singleRow, singleCol);
    lixop.setOutputBlocksizes(lhs.getRowsInBlock(), lhs.getColsInBlock());
    copyLineNumbers(lhs, lixop);
    lixop.refreshSizeInformation();
    return lixop;
}
Also used: LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)

Example 15 with LeftIndexingOp

Use of org.apache.sysml.hops.LeftIndexingOp in the Apache project incubator-systemml.

From class RewriteAlgebraicSimplificationDynamic, method removeEmptyLeftIndexing.

/**
 * Replaces a matrix left indexing op by an all-zero matrix when both inputs are known to be empty.
 * <p>
 * If the target (input 0) and the written values (input 1) both report {@code getNnz() == 0}, the
 * result is an all-zero matrix shaped like the target. The zero matrix is created via
 * {@code createDataGenOp(target, 0)} and takes {@code hi}'s place at input position {@code pos} of
 * {@code parent}.
 *
 * @param parent consumer of {@code hi}
 * @param hi     candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the replacement hop if the rewrite was applied, otherwise {@code hi}
 */
private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos) {
    // only matrix-valued left indexing ops are candidates
    if (!(hi instanceof LeftIndexingOp) || hi.getDataType() != DataType.MATRIX)
        return hi;
    Hop target = hi.getInput().get(0);
    Hop source = hi.getInput().get(1);
    // both inputs must have a known nnz of exactly zero
    if (target.getNnz() != 0 || source.getNnz() != 0)
        return hi;
    // remove the unnecessary left indexing: writing zeros into zeros yields zeros
    Hop zeros = HopRewriteUtils.createDataGenOp(target, 0);
    HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, source);
    LOG.debug("Applied removeEmptyLeftIndexing");
    return zeros;
}
Also used: Hop (org.apache.sysml.hops.Hop), LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)

Aggregations

LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 16 · Hop (org.apache.sysml.hops.Hop): 10 · LiteralOp (org.apache.sysml.hops.LiteralOp): 7 · IndexingOp (org.apache.sysml.hops.IndexingOp): 6 · DataOp (org.apache.sysml.hops.DataOp): 4 · MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 4 · ArrayList (java.util.ArrayList): 3 · ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 3 · IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 3 · StatementBlock (org.apache.sysml.parser.StatementBlock): 3 · WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock): 3 · AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 2 · BinaryOp (org.apache.sysml.hops.BinaryOp): 2 · UnaryOp (org.apache.sysml.hops.UnaryOp): 2 · AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 1 · FunctionOp (org.apache.sysml.hops.FunctionOp): 1 · MultiThreadedHop (org.apache.sysml.hops.Hop.MultiThreadedHop): 1 · ReorgOp (org.apache.sysml.hops.ReorgOp): 1