Search in sources :

Example 6 with IndexingOp

use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by apache.

the class HopRewriteUtils method createScalarIndexing.

/**
 * Builds a right-indexing operator that selects the single cell
 * {@code input[rix, cix]} and wraps it into a cast-as-scalar unary operator.
 *
 * @param input hop to index into
 * @param rix   row index, used as both lower and upper row bound
 * @param cix   column index, used as both lower and upper column bound
 * @return the cast-as-scalar hop created on top of the new indexing operator
 */
public static Hop createScalarIndexing(Hop input, long rix, long cix) {
    // lower and upper bound share the same literal per dimension
    LiteralOp rowBound = new LiteralOp(rix);
    LiteralOp colBound = new LiteralOp(cix);

    IndexingOp cell = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
            input, rowBound, rowBound, colBound, colBound, true, true);

    // inherit block sizes and line numbers from the indexed input
    cell.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
    copyLineNumbers(input, cell);
    cell.refreshSizeInformation();

    Hop scalar = createUnary(cell, OpOp1.CAST_AS_SCALAR);
    return scalar;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 7 with IndexingOp

use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method removeEmptyRightIndexing.

/**
 * Replaces a matrix right-indexing operator over an input with zero
 * non-zeros by a data generator producing a constant-0 matrix of the
 * indexing output dimensions.
 *
 * @param parent parent hop whose child reference at {@code pos} is rewired
 * @param hi     candidate hop
 * @param pos    position of {@code hi} in the parent's inputs
 * @return the new data generator hop if the rewrite applied, otherwise {@code hi}
 * @throws HopsException declared by the rewrite interface
 */
private Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) throws HopsException {
    // only matrix right indexing qualifies
    if (!(hi instanceof IndexingOp) || hi.getDataType() != DataType.MATRIX) {
        return hi;
    }

    Hop source = hi.getInput().get(0);

    // requires a known-empty input (nnz == 0) and known output dimensions
    if (source.getNnz() != 0 || !HopRewriteUtils.isDimsKnown(hi)) {
        return hi;
    }

    // substitute the indexing by an all-zero matrix of the same output size
    Hop zeros = HopRewriteUtils.createDataGenOpByVal(new LiteralOp(hi.getDim1()), new LiteralOp(hi.getDim2()), 0);
    HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, source);
    LOG.debug("Applied removeEmptyRightIndexing");
    return zeros;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 8 with IndexingOp

use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationStatic method simplifySlicedMatrixMult.

/**
 * Pushes a single-cell right indexing below a matrix multiplication,
 * e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1].
 *
 * The rewrite applies only if {@code hi} is an indexing operator with equal
 * lower/upper bounds in both dimensions, and its input is a matrix multiply
 * that has exactly one parent.
 *
 * @param parent parent hop of {@code hi} (not used by this rewrite)
 * @param hi     candidate hop
 * @param pos    position of {@code hi} in the parent's inputs (not used by this rewrite)
 * @return the rewired matrix multiply if the rewrite applied, otherwise {@code hi}
 * @throws HopsException declared by the rewrite interface
 */
private Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) throws HopsException {
    if (!(hi instanceof IndexingOp)) {
        return hi;
    }
    IndexingOp cellIx = (IndexingOp) hi;
    if (!cellIx.isRowLowerEqualsUpper() || !cellIx.isColLowerEqualsUpper()) {
        return hi;
    }

    // the indexing must be the only consumer of a matrix multiply
    Hop mm = hi.getInput().get(0);
    if (mm.getParent().size() != 1 || !HopRewriteUtils.isMatrixMultiply(mm)) {
        return hi;
    }

    Hop X = mm.getInput().get(0);
    Hop Y = mm.getInput().get(1);
    Hop rowExpr = hi.getInput().get(1); // rl == ru
    Hop colExpr = hi.getInput().get(3); // cl == cu

    // detach the original operands before building the sliced inputs
    HopRewriteUtils.removeAllChildReferences(mm);

    // left operand: row rowExpr of X, columns from 1 to the value hop created for X
    Hop xColLower = new LiteralOp(1);
    Hop xColUpper = HopRewriteUtils.createValueHop(X, false);
    IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE,
            X, rowExpr, rowExpr, xColLower, xColUpper, true, false);
    ix1.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
    ix1.refreshSizeInformation();

    // right operand: column colExpr of Y, rows from 1 to the value hop created for Y
    Hop yRowLower = new LiteralOp(1);
    Hop yRowUpper = HopRewriteUtils.createValueHop(Y, true);
    IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE,
            Y, yRowLower, yRowUpper, colExpr, colExpr, false, true);
    ix2.setOutputBlocksizes(Y.getRowsInBlock(), Y.getColsInBlock());
    ix2.refreshSizeInformation();

    // rewire matrix mult over ix1 and ix2
    HopRewriteUtils.addChildReference(mm, ix1, 0);
    HopRewriteUtils.addChildReference(mm, ix2, 1);
    mm.refreshSizeInformation();

    LOG.debug("Applied simplifySlicedMatrixMult");
    return mm;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 9 with IndexingOp

use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by apache.

the class RewriteForLoopVectorization method vectorizeScalarAggregate.

/**
 * Vectorizes a for-loop body of the form {@code s = s op as.scalar(X[i, ...])}
 * (or the mirrored operand order, or with the loop variable as column index)
 * where {@code op} is one of MAP_SCALAR_AGGREGATE_SOURCE_OPS.
 *
 * If the pattern matches, the cast-as-scalar is replaced by a full (RowCol)
 * unary aggregate with the corresponding entry of
 * MAP_SCALAR_AGGREGATE_TARGET_OPS, and the index bounds that used the loop
 * variable are replaced by the loop range {@code from}..{@code to}.
 *
 * @param sb        original (for) statement block, returned if nothing was rewritten
 * @param csb       single child statement block of the loop body
 * @param from      loop lower bound
 * @param to        loop upper bound
 * @param increment loop increment; only a literal 1 is supported
 * @param itervar   name of the loop iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 * @throws HopsException declared by the rewrite interface
 */
private StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) throws HopsException {
    // check missing and supported increment values
    boolean unitIncrement = increment instanceof LiteralOp
            && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitIncrement) {
        return sb;
    }

    // the loop body must consist of exactly one root
    if (csb.get_hops() == null || csb.get_hops().size() != 1) {
        return sb;
    }
    Hop root = csb.get_hops().get(0);
    if (root.getDataType() != DataType.SCALAR || !(root.getInput().get(0) instanceof BinaryOp)) {
        return sb;
    }

    BinaryOp bop = (BinaryOp) root.getInput().get(0);
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    // locate the cast operand: position 1 if the accumulator is on the left, 0 if on the right
    Hop cast;
    int castPos;
    if (isScalarAccumulatorWithCastIndexing(root, bop, left, right)) {
        cast = right;
        castPos = 1;
    } else if (isScalarAccumulatorWithCastIndexing(root, bop, right, left)) {
        cast = left;
        castPos = 0;
    } else {
        return sb;
    }

    // the loop variable must be the single row index or the single column index
    IndexingOp ix = (IndexingOp) cast.getInput().get(0);
    boolean rowIx;
    if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp
            && ix.getInput().get(1).getName().equals(itervar)) {
        rowIx = true;
    } else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp
            && ix.getInput().get(3).getName().equals(itervar)) {
        rowIx = false;
    } else {
        return sb;
    }

    // map the binary operator to its aggregate counterpart
    int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
    AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];

    // replace cast with sum
    AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
    HopRewriteUtils.removeChildReference(cast, ix);
    HopRewriteUtils.removeChildReference(bop, cast);
    HopRewriteUtils.addChildReference(bop, newSum, castPos);

    // modify indexing expression according to loop predicate from-to
    // NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
    int lowerPos = rowIx ? 1 : 3;
    int upperPos = rowIx ? 2 : 4;
    HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(lowerPos), from, lowerPos);
    HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(upperPos), to, upperPos);

    // update indexing size information: the vectorized dimension is now a range
    if (rowIx) {
        ix.setRowLowerEqualsUpper(false);
    } else {
        ix.setColLowerEqualsUpper(false);
    }
    ix.refreshSizeInformation();

    LOG.debug("Applied vectorizeScalarSumForLoop.");
    return csb;
}

/**
 * Checks whether {@code acc} is the scalar variable written by {@code root}
 * and {@code cast} is a cast-as-scalar over an indexing operator, combined
 * by a binary operator contained in MAP_SCALAR_AGGREGATE_SOURCE_OPS.
 */
private boolean isScalarAccumulatorWithCastIndexing(Hop root, BinaryOp bop, Hop acc, Hop cast) {
    return HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
            && acc instanceof DataOp
            && acc.getDataType() == DataType.SCALAR
            && root.getName().equals(acc.getName())
            && cast instanceof UnaryOp
            && ((UnaryOp) cast).getOp() == OpOp1.CAST_AS_SCALAR
            && cast.getInput().get(0) instanceof IndexingOp;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggOp(org.apache.sysml.hops.Hop.AggOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 10 with IndexingOp

use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by apache.

the class RewriteIndexingVectorization method vectorizeLeftIndexing.

/**
 * Vectorizes chains of single-cell left-indexing operations that all write
 * into the same row (or, if that does not apply, the same column).
 *
 * For a qualifying chain, the shared row/column is first extracted from the
 * chain's base input via right indexing, the chained left-indexing operators
 * are redirected to that slice (their shared index bounds become literal 1),
 * and one new left-indexing operator writes the updated slice back into the
 * base input. All former parents of {@code hop} are rewired to that new
 * operator.
 *
 * @param hop candidate hop; anything but a single-cell left indexing is ignored
 * @throws HopsException declared by the rewrite interface
 */
private void vectorizeLeftIndexing(Hop hop) throws HopsException {
    if (!(hop instanceof LeftIndexingOp)) {
        return;
    }
    LeftIndexingOp ihop0 = (LeftIndexingOp) hop;

    // both merges start from a single-cell left indexing
    if (!ihop0.getRowLowerEqualsUpper() || !ihop0.getColLowerEqualsUpper()) {
        return;
    }

    // prefer the row merge; try the column merge only if the row merge did not apply
    if (!mergeLeftIndexingChain(ihop0, true)) {
        mergeLeftIndexingChain(ihop0, false);
    }
}

/**
 * Collects the simple chain of left-indexing operators below {@code ihop0}
 * that share the same row (byRow) or column (!byRow) expression and, if the
 * chain has more than one element, rewrites it into slice extraction,
 * updates on the slice, and a single write-back.
 *
 * @param ihop0 top-most left-indexing operator of the chain
 * @param byRow true to merge on a shared row, false to merge on a shared column
 * @return true if the rewrite was applied
 * @throws HopsException declared by the rewrite interface
 */
private boolean mergeLeftIndexingChain(LeftIndexingOp ihop0, boolean byRow) throws HopsException {
    // input positions of the lower/upper bound of the shared dimension
    // (2/3 are the row bounds, 4/5 the column bounds of a left indexing)
    final int lowerPos = byRow ? 2 : 4;
    final int upperPos = lowerPos + 1;

    // collect simple chains (w/o multiple consumers) of left indexing ops
    ArrayList<LeftIndexingOp> ihops = new ArrayList<LeftIndexingOp>();
    ihops.add(ihop0);
    Hop current = ihop0;
    while (current.getInput().get(0) instanceof LeftIndexingOp) {
        LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
        Hop target = tmp.getInput().get(0);
        if (tmp.getParent().size() > 1 // multiple consumers, i.e., not a simple chain
                // merge on this dimension not applicable
                || !(byRow ? tmp.getRowLowerEqualsUpper() : tmp.getColLowerEqualsUpper())
                // not the same row/column expression (identity comparison)
                || tmp.getInput().get(lowerPos) != ihop0.getInput().get(lowerPos)
                // target is a single column/row or its size is unknown
                || (byRow ? target.getDim2() : target.getDim1()) <= 1) {
            break;
        }
        ihops.add(tmp);
        current = tmp;
    }

    // apply rewrite only if candidates were found
    if (ihops.size() <= 1) {
        return false;
    }

    Hop input = current.getInput().get(0);
    // keep the shared index expression before the bounds are reset below
    Hop ixExpr = ihop0.getInput().get(lowerPos);

    // new right indexing operator extracting the shared row/column of the input;
    // the open dimension ranges from 1 to the value hop created for the input
    // (presumably its column/row count -- confirm against createValueHop)
    IndexingOp newRix = byRow
            ? new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, ixExpr, ixExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false)
            : new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), ixExpr, ixExpr, false, true);
    HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
    newRix.refreshSizeInformation();

    // rewrite bottom left indexing operator to consume the slice instead of the input
    HopRewriteUtils.removeChildReference(current, input);
    HopRewriteUtils.addChildReference(current, newRix, 0);

    // reset the shared index of all candidates to 1 and refresh sizes (bottom-up)
    for (int i = ihops.size() - 1; i >= 0; i--) {
        LeftIndexingOp c = ihops.get(i);
        // lower bound
        HopRewriteUtils.replaceChildReference(c, c.getInput().get(lowerPos), new LiteralOp(1), lowerPos);
        // upper bound
        HopRewriteUtils.replaceChildReference(c, c.getInput().get(upperPos), new LiteralOp(1), upperPos);
        if (byRow) {
            c.setRowLowerEqualsUpper(true);
        } else {
            c.setColLowerEqualsUpper(true);
        }
        c.refreshSizeInformation();
    }

    // new left indexing operator writing the slice back (for all parents, only
    // intermediates are guaranteed to have 1 parent)
    // (note: it's important to copy the parent list before creating newLix on top of ihop0)
    ArrayList<Hop> ihop0parents = new ArrayList<Hop>(ihop0.getParent());
    ArrayList<Integer> ihop0parentsPos = new ArrayList<Integer>();
    for (Hop parent : ihop0parents) {
        int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
        HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
        ihop0parentsPos.add(posp);
    }

    LeftIndexingOp newLix = byRow
            ? new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, ixExpr, ixExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false)
            : new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), ixExpr, ixExpr, false, true);
    HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
    newLix.refreshSizeInformation();

    // re-attach the former parents of ihop0 at their original input positions
    for (int i = 0; i < ihop0parentsPos.size(); i++) {
        Hop parent = ihop0parents.get(i);
        int posp = ihop0parentsPos.get(i);
        HopRewriteUtils.addChildReference(parent, newLix, posp);
    }

    LOG.debug(byRow ? "Applied vectorizeLeftIndexingRow" : "Applied vectorizeLeftIndexingCol");
    return true;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) ArrayList(java.util.ArrayList) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp)

Aggregations

- IndexingOp (org.apache.sysml.hops.IndexingOp): 17
- Hop (org.apache.sysml.hops.Hop): 16
- LiteralOp (org.apache.sysml.hops.LiteralOp): 15
- LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 10
- AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 8
- DataOp (org.apache.sysml.hops.DataOp): 8
- UnaryOp (org.apache.sysml.hops.UnaryOp): 7
- BinaryOp (org.apache.sysml.hops.BinaryOp): 5
- ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 4
- IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 4
- StatementBlock (org.apache.sysml.parser.StatementBlock): 4
- WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock): 4
- AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 3
- ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 3
- TernaryOp (org.apache.sysml.hops.TernaryOp): 3
- MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 3
- ArrayList (java.util.ArrayList): 2
- ReorgOp (org.apache.sysml.hops.ReorgOp): 2
- CNode (org.apache.sysml.hops.codegen.cplan.CNode): 2
- CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 2