Search in sources :

Example 36 with StatementBlock

use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.

the class RewriteForLoopVectorization method rewriteStatementBlock.

/**
 * Tries to replace a simple for loop by an equivalent vectorized statement block.
 *
 * <p>A loop is considered only if it is a {@link ForStatementBlock} whose body
 * consists of exactly one child block that is not itself a while, if, or for
 * block. For such loops the following patterns are attempted in order, each
 * one receiving the result of the previous attempt:
 * <ul>
 *   <li>scalar aggregate, e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2])</li>
 *   <li>elementwise binary, e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3]</li>
 *   <li>elementwise unary, e.g., for(i in a:b){X[i,2] = abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1])</li>
 *   <li>indexed copy, e.g., for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b]</li>
 * </ul>
 * Unnecessary row or column indexing introduced here is expected to be removed
 * later via hop rewrites.
 *
 * @param sb    the statement block to inspect
 * @param state the rewrite status (not read by this method)
 * @return a single-element list holding either the original block or, if a
 *         pattern was applied, the block returned by the vectorization helper
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    // anything but a for loop is passed through untouched
    if (!(sb instanceof ForStatementBlock)) {
        return Arrays.asList(sb);
    }
    ForStatementBlock forBlock = (ForStatementBlock) sb;
    ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
    Hop lower = forBlock.getFromHops();
    Hop upper = forBlock.getToHops();
    Hop step = forBlock.getIncrementHops();
    String loopVar = forBlock.getIterPredicate().getIterVar().getName();

    // the loop body must consist of a single child block
    if (forStmt.getBody() == null || forStmt.getBody().size() != 1) {
        return Arrays.asList(sb);
    }
    StatementBlock child = (StatementBlock) forStmt.getBody().get(0);

    // the child must be a last-level block (no nested control flow)
    boolean nestedControlFlow = child instanceof WhileStatementBlock
        || child instanceof IfStatementBlock
        || child instanceof ForStatementBlock;
    if (nestedControlFlow) {
        return Arrays.asList(sb);
    }

    // AUTO VECTORIZATION PATTERNS (each step sees the outcome of the previous one)
    StatementBlock result = sb;
    result = vectorizeScalarAggregate(result, child, lower, upper, step, loopVar);
    result = vectorizeElementwiseBinary(result, child, lower, upper, step, loopVar);
    result = vectorizeElementwiseUnary(result, child, lower, upper, step, loopVar);
    result = vectorizeIndexedCopy(result, child, lower, upper, step, loopVar);
    return Arrays.asList(result);
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) Hop(org.apache.sysml.hops.Hop) ForStatement(org.apache.sysml.parser.ForStatement) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 37 with StatementBlock

use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.

the class RewriteForLoopVectorization method vectorizeElementwiseUnary.

/**
 * Vectorizes a loop body of the form {@code X[i,2] = abs(Y[i,1])} into
 * {@code X[a:b,2] = abs(Y[a:b,1])} by substituting the loop bounds for the
 * iteration variable in both the left and the right indexing operation.
 *
 * <p>The rewrite is applied only if the increment is the literal 1, the child
 * block has exactly one hop root of matrix type whose first input is a left
 * indexing over a data op, the assigned value is a unary op over a right
 * indexing of a data op, and {@code checkLeftAndRightIndexing} reports the
 * pattern as applicable.
 *
 * @param sb        the enclosing for-loop block, returned if nothing is rewritten
 * @param csb       the single child block of the loop; modified in place and
 *                  returned if the rewrite applies
 * @param from      hop of the loop's lower bound
 * @param to        hop of the loop's upper bound
 * @param increment hop of the loop's increment
 * @param itervar   name of the loop's iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeElementwiseUnary(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // only unit increments are supported
    boolean unitStep = increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitStep) {
        return sb;
    }
    // the child block needs exactly one hop root
    if (csb.getHops() == null || csb.getHops().size() != 1) {
        return sb;
    }
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp)) {
        return sb;
    }
    LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
    Hop target = lix.getInput().get(0);
    Hop source = lix.getInput().get(1);
    // expected shape: data[...] = unary(data[...])
    boolean shapeMatches = target instanceof DataOp
        && source instanceof UnaryOp
        && source.getInput().get(0) instanceof IndexingOp
        && source.getInput().get(0).getInput().get(0) instanceof DataOp;
    if (!shapeMatches) {
        return sb;
    }
    // check[0]: rewrite applicable, check[1]: row (true) or column (false) indexing
    boolean[] check = checkLeftAndRightIndexing(lix, (IndexingOp) source.getInput().get(0), itervar);
    if (!check[0]) {
        return sb;
    }
    boolean rowIx = check[1];

    // apply rewrite
    UnaryOp uop = (UnaryOp) lix.getInput().get(1);
    IndexingOp rix = (IndexingOp) uop.getInput().get(0);
    // input positions of the lower/upper bound in the left indexing op;
    // the right indexing op uses the same positions shifted down by one
    int lowerPos = rowIx ? 2 : 4;
    int upperPos = rowIx ? 3 : 5;
    // modify left indexing bounds
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lowerPos), from, lowerPos);
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(upperPos), to, upperPos);
    // modify right indexing bounds
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(lowerPos - 1), from, lowerPos - 1);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(upperPos - 1), to, upperPos - 1);
    updateLeftAndRightIndexingSizes(rowIx, lix, rix);
    uop.refreshSizeInformation();
    // left indexing is refreshed only after the unary op has been updated
    lix.refreshSizeInformation();
    LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
    return csb;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp)

Example 38 with StatementBlock

use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.

the class RewriteForLoopVectorization method vectorizeIndexedCopy.

/**
 * Vectorizes a loop body of the form {@code X[7,i] = Y[1,i]} into
 * {@code X[7,a:b] = Y[1,a:b]} by substituting the loop bounds for the
 * iteration variable in both the left and the right indexing operation.
 *
 * <p>The rewrite is applied only if the increment is the literal 1, the child
 * block has exactly one hop root of matrix type whose first input is a left
 * indexing over a data op, the assigned value is a right indexing of a data
 * op, and {@code checkLeftAndRightIndexing} reports the pattern as applicable.
 *
 * @param sb        the enclosing for-loop block, returned if nothing is rewritten
 * @param csb       the single child block of the loop; modified in place and
 *                  returned if the rewrite applies
 * @param from      hop of the loop's lower bound
 * @param to        hop of the loop's upper bound
 * @param increment hop of the loop's increment
 * @param itervar   name of the loop's iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeIndexedCopy(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // only unit increments are supported
    boolean unitStep = increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitStep) {
        return sb;
    }
    // the child block needs exactly one hop root
    if (csb.getHops() == null || csb.getHops().size() != 1) {
        return sb;
    }
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp)) {
        return sb;
    }
    LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
    Hop target = lix.getInput().get(0);
    Hop source = lix.getInput().get(1);
    // expected shape: data[...] = data[...]
    boolean shapeMatches = target instanceof DataOp
        && source instanceof IndexingOp
        && source.getInput().get(0) instanceof DataOp;
    if (!shapeMatches) {
        return sb;
    }
    // check[0]: rewrite applicable, check[1]: row (true) or column (false) indexing
    boolean[] check = checkLeftAndRightIndexing(lix, (IndexingOp) source, itervar);
    if (!check[0]) {
        return sb;
    }
    boolean rowIx = check[1];

    // apply rewrite
    IndexingOp rix = (IndexingOp) lix.getInput().get(1);
    // input positions of the lower/upper bound in the left indexing op;
    // the right indexing op uses the same positions shifted down by one
    int lowerPos = rowIx ? 2 : 4;
    int upperPos = rowIx ? 3 : 5;
    // modify left indexing bounds
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lowerPos), from, lowerPos);
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(upperPos), to, upperPos);
    // modify right indexing bounds
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(lowerPos - 1), from, lowerPos - 1);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(upperPos - 1), to, upperPos - 1);
    updateLeftAndRightIndexingSizes(rowIx, lix, rix);
    LOG.debug("Applied vectorizeIndexedCopy.");
    return csb;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp)

Example 39 with StatementBlock

use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.

the class RewriteForLoopVectorization method vectorizeScalarAggregate.

/**
 * Vectorizes a scalar accumulation loop such as
 * {@code for(i in a:b){s = s + as.scalar(X[i,2])}} into {@code s = sum(X[a:b,2])}.
 *
 * <p>The rewrite is applied only if the increment is the literal 1 and the
 * child block has exactly one scalar hop root whose first input is a binary
 * op with an operator listed in {@code MAP_SCALAR_AGGREGATE_SOURCE_OPS}. One
 * operand must be a scalar data op with the same name as the root (the
 * accumulator); the other must be a cast-as-scalar over a right indexing
 * whose row or column index is the iteration variable. The cast is replaced
 * by a full aggregate (operator taken from {@code MAP_SCALAR_AGGREGATE_TARGET_OPS}
 * at the matching position) over the indexing, and the indexing bounds are
 * replaced by the loop bounds. Redundant index operations are expected to be
 * removed later via dynamic algebraic simplification rewrites.
 *
 * @param sb        the enclosing for-loop block, returned if nothing is rewritten
 * @param csb       the single child block of the loop; modified in place and
 *                  returned if the rewrite applies
 * @param from      hop of the loop's lower bound
 * @param to        hop of the loop's upper bound
 * @param increment hop of the loop's increment; may be {@code null}
 * @param itervar   name of the loop's iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // missing or non-unit increments are not supported (instanceof is false for null)
    if (!(increment instanceof LiteralOp) || ((LiteralOp) increment).getDoubleValue() != 1.0) {
        return sb;
    }
    // the child block needs exactly one hop root
    if (csb.getHops() == null || csb.getHops().size() != 1) {
        return sb;
    }
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.SCALAR || !(root.getInput().get(0) instanceof BinaryOp)) {
        return sb;
    }
    BinaryOp bop = (BinaryOp) root.getInput().get(0);
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);

    // accumulator on the left or on the right of the binary op
    boolean leftScalar = false;
    boolean rightScalar = false;
    // row or col
    boolean rowIx = false;

    // accumulator on the left: s = s op as.scalar(X[...])
    boolean accumulatorLeft = HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
        && lhs instanceof DataOp
        && lhs.getDataType() == DataType.SCALAR
        && root.getName().equals(lhs.getName())
        && rhs instanceof UnaryOp
        && ((UnaryOp) rhs).getOp() == OpOp1.CAST_AS_SCALAR
        && rhs.getInput().get(0) instanceof IndexingOp;
    if (accumulatorLeft) {
        IndexingOp indexed = (IndexingOp) rhs.getInput().get(0);
        boolean byRow = indexed.isRowLowerEqualsUpper()
            && indexed.getInput().get(1) instanceof DataOp
            && indexed.getInput().get(1).getName().equals(itervar);
        boolean byCol = !byRow
            && indexed.isColLowerEqualsUpper()
            && indexed.getInput().get(3) instanceof DataOp
            && indexed.getInput().get(3).getName().equals(itervar);
        leftScalar = byRow || byCol;
        rowIx = byRow;
    } else {
        // accumulator on the right: s = as.scalar(X[...]) op s
        boolean accumulatorRight = HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
            && rhs instanceof DataOp
            && rhs.getDataType() == DataType.SCALAR
            && root.getName().equals(rhs.getName())
            && lhs instanceof UnaryOp
            && ((UnaryOp) lhs).getOp() == OpOp1.CAST_AS_SCALAR
            && lhs.getInput().get(0) instanceof IndexingOp;
        if (accumulatorRight) {
            IndexingOp indexed = (IndexingOp) lhs.getInput().get(0);
            boolean byRow = indexed.isRowLowerEqualsUpper()
                && indexed.getInput().get(1) instanceof DataOp
                && indexed.getInput().get(1).getName().equals(itervar);
            boolean byCol = !byRow
                && indexed.isColLowerEqualsUpper()
                && indexed.getInput().get(3) instanceof DataOp
                && indexed.getInput().get(3).getName().equals(itervar);
            rightScalar = byRow || byCol;
            rowIx = byRow;
        }
    }
    if (!leftScalar && !rightScalar) {
        return sb;
    }

    // apply rewrite: the cast sits opposite to the accumulator
    int castPos = leftScalar ? 1 : 0;
    Hop cast = bop.getInput().get(castPos);
    Hop ix = cast.getInput().get(0);
    int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
    AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
    // replace cast with the full aggregate
    AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
    HopRewriteUtils.removeChildReference(cast, ix);
    HopRewriteUtils.removeChildReference(bop, cast);
    HopRewriteUtils.addChildReference(bop, newSum, castPos);
    // modify indexing expression according to loop predicate from-to
    int lowerPos = rowIx ? 1 : 3;
    int upperPos = rowIx ? 2 : 4;
    HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(lowerPos), from, lowerPos);
    HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(upperPos), to, upperPos);
    // update indexing size information: the rewritten dimension is now a range
    if (rowIx) {
        ((IndexingOp) ix).setRowLowerEqualsUpper(false);
    } else {
        ((IndexingOp) ix).setColLowerEqualsUpper(false);
    }
    ix.refreshSizeInformation();
    LOG.debug("Applied vectorizeScalarSumForLoop.");
    return csb;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggOp(org.apache.sysml.hops.Hop.AggOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 40 with StatementBlock

use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.

the class RewriteMergeBlockSequence method rewriteStatementBlocks.

/**
 * Merges adjacent statement blocks of a sequence pairwise until no further
 * merge is possible.
 *
 * <p>Two neighboring blocks sb1 and sb2 are merged if both are last-level
 * blocks, neither is marked as split DAG, they do not both have an external
 * function op root with side effect, and for each block that has a function
 * op root there is no function I/O conflict with the other block (as decided
 * by the respective helper predicates). On a merge, the hop roots of sb1 are
 * moved into sb2 (sb1's hop list is cleared), sb2's live-variable sets and
 * begin position are updated, and sb1 is dropped from the result.
 *
 * @param sbs  the sequence of statement blocks; may be {@code null} or empty
 * @param sate the rewrite status (not read by this method)
 * @return {@code sbs} itself if it is {@code null} or empty, otherwise a new
 *         list holding the remaining blocks after merging
 */
@Override
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) {
    if (sbs == null || sbs.isEmpty())
        return sbs;
    // execute binary merging iterations until fixpoint
    // (work on a copy so the caller's list is not modified)
    ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
    boolean merged = true;
    while (merged) {
        merged = false;
        // scan neighboring pairs (i, i+1); stop at the first pair that can be merged
        for (int i = 0; i < tmpList.size() - 1; i++) {
            StatementBlock sb1 = tmpList.get(i);
            StatementBlock sb2 = tmpList.get(i + 1);
            if (HopRewriteUtils.isLastLevelStatementBlock(sb1) && HopRewriteUtils.isLastLevelStatementBlock(sb2) && !sb1.isSplitDag() && !sb2.isSplitDag() && !(hasExternalFunctionOpRootWithSideEffect(sb1) && hasExternalFunctionOpRootWithSideEffect(sb2)) && (!hasFunctionOpRoot(sb1) || !hasFunctionIOConflict(sb1, sb2)) && (!hasFunctionOpRoot(sb2) || !hasFunctionIOConflict(sb2, sb1))) {
                // note: we intend to merge sb1 into sb2 to connect data dependencies
                // however, we work with a temporary list of root nodes to preserve
                // the original order of roots, which affects prints w/o dependencies
                ArrayList<Hop> sb1Hops = sb1.getHops();
                ArrayList<Hop> sb2Hops = sb2.getHops();
                ArrayList<Hop> newHops = new ArrayList<>();
                // determine transient read inputs s2
                // (collected by variable name, together with the transient writes of s2)
                Hop.resetVisitStatus(sb2Hops);
                HashMap<String, Hop> treads = new HashMap<>();
                HashMap<String, Hop> twrites = new HashMap<>();
                for (Hop root : sb2Hops) rCollectTransientReadWrites(root, treads, twrites);
                Hop.resetVisitStatus(sb2Hops);
                // merge hop dags of s1 and s2
                Hop.resetVisitStatus(sb1Hops);
                for (Hop root : sb1Hops) {
                    // connect transient writes s1 and reads s2
                    if (HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) && treads.containsKey(root.getName())) {
                        // rewire transient write and transient read
                        // (iterate over a copy of the parent list because it is modified by the rewiring)
                        Hop tread = treads.get(root.getName());
                        Hop in = root.getInput().get(0);
                        for (Hop parent : new ArrayList<>(tread.getParent())) HopRewriteUtils.replaceChildReference(parent, tread, in);
                        HopRewriteUtils.removeAllChildReferences(root);
                        // add transient write if necessary, i.e., if s2 does not write the
                        // variable itself but the variable is live after s2
                        if (!twrites.containsKey(root.getName()) && sb2.liveOut().containsVariable(root.getName())) {
                            newHops.add(HopRewriteUtils.createDataOp(root.getName(), in, DataOpTypes.TRANSIENTWRITE));
                        }
                    } else // add remaining roots from s1 to s2
                    // (skip transient writes that s2 overwrites or that are not live after s2)
                    if (!(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) && (twrites.containsKey(root.getName()) || !sb2.liveOut().containsVariable(root.getName())))) {
                        newHops.add(root);
                    }
                }
                // clear partial hops from the merged statement block to avoid problems with
                // other statement block rewrites that iterate over the original program
                sb1Hops.clear();
                // append all root nodes of s2 after root nodes of s1
                newHops.addAll(sb2Hops);
                sb2.setHops(newHops);
                // run common-subexpression elimination
                // (via the rewriter field with a fresh rewrite status)
                Hop.resetVisitStatus(sb2.getHops());
                rewriter.rewriteHopDAG(sb2.getHops(), new ProgramRewriteStatus());
                // modify live variable sets of s2
                // liveOut remains unchanged
                sb2.setLiveIn(sb1.liveIn());
                sb2.setGen(VariableSet.minus(VariableSet.union(sb1.getGen(), sb2.getGen()), sb1.getKill()));
                sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
                sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
                sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
                // log before the begin line of s2 is overwritten below
                LOG.debug("Applied mergeStatementBlockSequences " + "(blocks of lines " + sb1.getBeginLine() + "-" + sb1.getEndLine() + " and " + sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");
                // modify line numbers of s2 so that it starts where s1 started
                sb2.setBeginLine(sb1.getBeginLine());
                sb2.setBeginColumn(sb1.getBeginColumn());
                // remove sb1 from list of statement blocks
                tmpList.remove(i);
                merged = true;
                // leave the for loop; the while loop restarts the scan on the shortened list
                break;
            }
        }
    }
    return tmpList;
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Hop(org.apache.sysml.hops.Hop) StatementBlock(org.apache.sysml.parser.StatementBlock) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock)

Aggregations

StatementBlock (org.apache.sysml.parser.StatementBlock)67 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)57 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)57 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)57 FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock)39 Hop (org.apache.sysml.hops.Hop)28 ArrayList (java.util.ArrayList)24 FunctionStatement (org.apache.sysml.parser.FunctionStatement)22 IfStatement (org.apache.sysml.parser.IfStatement)22 ForStatement (org.apache.sysml.parser.ForStatement)20 WhileStatement (org.apache.sysml.parser.WhileStatement)19 ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock)18 ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock)18 IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock)16 WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock)16 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)13 ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock)13 HashSet (java.util.HashSet)11 ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement)11 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)11