Search in sources :

Example 51 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

the class RewriteForLoopVectorization method rewriteStatementBlock.

/**
 * Tries to apply the auto-vectorization patterns to a for loop whose body
 * consists of exactly one statement block that is not itself a while, if,
 * or for block. Any other input is returned unchanged.
 *
 * @param sb    statement block to inspect
 * @param state program rewrite status (not read by this method)
 * @return single-element list holding either the given block or the block
 *         produced by the vectorization helpers
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    // only for loops are candidates
    if (!(sb instanceof ForStatementBlock)) {
        return Arrays.asList(sb);
    }

    ForStatementBlock forBlock = (ForStatementBlock) sb;
    ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
    Hop lower = forBlock.getFromHops();
    Hop upper = forBlock.getToHops();
    Hop step = forBlock.getIncrementHops();
    String loopVar = forBlock.getIterPredicate().getIterVar().getName();

    // the loop body must consist of a single child block
    boolean singleChild = forStmt.getBody() != null && forStmt.getBody().size() == 1;
    if (!singleChild) {
        return Arrays.asList(sb);
    }

    // the child must be a last-level block, i.e., no nested control flow
    StatementBlock child = (StatementBlock) forStmt.getBody().get(0);
    boolean nestedControlFlow = child instanceof WhileStatementBlock
        || child instanceof IfStatementBlock
        || child instanceof ForStatementBlock;
    if (nestedControlFlow) {
        return Arrays.asList(sb);
    }

    // AUTO VECTORIZATION PATTERNS
    // Note: unnecessary row or column indexing is later removed via hop rewrites.
    // Each helper receives the block returned by the previous one.
    StatementBlock result = sb;
    // e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2])
    result = vectorizeScalarAggregate(result, child, lower, upper, step, loopVar);
    // e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
    result = vectorizeElementwiseBinary(result, child, lower, upper, step, loopVar);
    // e.g., for(i in a:b){X[i,2] = abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1]);
    result = vectorizeElementwiseUnary(result, child, lower, upper, step, loopVar);
    // e.g., for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b];
    result = vectorizeIndexedCopy(result, child, lower, upper, step, loopVar);

    // whatever the helpers handed back is the (single) rewritten block
    return Arrays.asList(result);
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) Hop(org.apache.sysml.hops.Hop) ForStatement(org.apache.sysml.parser.ForStatement) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 52 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

the class RewriteMarkLoopVariablesUpdateInPlace method rewriteStatementBlock.

/**
 * Collects, for while and for (incl parfor) blocks, the names of updated
 * matrix variables that are also live-out and for which
 * {@code rIsApplicableForUpdateInPlace} accepts the loop body, and stores
 * them on the block via {@code setUpdateInPlaceVars}. On the HADOOP and
 * SPARK runtime platforms, and for non-loop blocks, nothing is changed.
 *
 * @param sb     statement block to inspect
 * @param status program rewrite status (not read by this method)
 * @return single-element list holding the given statement block
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) {
    // nothing to do on these platforms, return original statement block
    boolean skipPlatform = DMLScript.rtplatform == RUNTIME_PLATFORM.HADOOP
        || DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK;
    if (skipPlatform) {
        return Arrays.asList(sb);
    }

    // only loops are of interest (for includes parfor)
    boolean isLoop = sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock;
    if (!isLoop) {
        return Arrays.asList(sb);
    }

    ArrayList<String> inPlaceVars = new ArrayList<>();
    VariableSet updatedVars = sb.variablesUpdated();
    VariableSet liveOutVars = sb.liveOut();
    for (String name : updatedVars.getVariableNames()) {
        // consider matrices only
        if (updatedVars.getVariable(name).getDataType() != DataType.MATRIX) {
            continue;
        }
        // exclude local vars
        if (!liveOutVars.containsVariable(name)) {
            continue;
        }

        // check the loop body of the respective loop type
        boolean applicable;
        if (sb instanceof WhileStatementBlock) {
            WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
            applicable = rIsApplicableForUpdateInPlace(whileStmt.getBody(), name);
        } else {
            // sb is a ForStatementBlock here due to the loop check above
            ForStatement forStmt = (ForStatement) sb.getStatement(0);
            applicable = rIsApplicableForUpdateInPlace(forStmt.getBody(), name);
        }
        if (applicable) {
            inPlaceVars.add(name);
        }
    }
    sb.setUpdateInPlaceVars(inPlaceVars);

    // return modified statement block
    return Arrays.asList(sb);
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) VariableSet(org.apache.sysml.parser.VariableSet) ArrayList(java.util.ArrayList) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock)

Example 53 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

the class InterProceduralAnalysis method propagateStatisticsAcrossBlock.

// ///////////////////////////
// INTRA-PROCEDURE ANALYSIS
// ////
/**
 * Propagates statistics through the given statement block, dispatching on
 * the block type: function, while, if, for (incl parfor), or generic
 * last-level block. Control-flow blocks recurse into their child blocks.
 *
 * @param sb         statement block to process
 * @param callVars   variable map passed to all propagation and reconcile calls
 * @param fcallSizes function call size info, passed through to function propagation
 * @param fnStack    set of function names, passed through to function propagation
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funBlock = (FunctionStatementBlock) sb;
        FunctionStatement funStmt = (FunctionStatement) funBlock.getStatement(0);
        for (StatementBlock child : funStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
    } else if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        // old stats into predicate
        propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, whileBlock);
        // first pass over the body, remembering the variables as they were before
        LocalVariableMap varsBefore = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : whileStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        if (Recompiler.reconcileUpdatedCallVarsLoops(varsBefore, callVars, whileBlock)) {
            // second pass (predicate and body) if required
            propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
            for (StatementBlock child : whileStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
            }
        }
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, sb);
    } else if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        // old stats into predicate
        propagateStatisticsAcrossPredicateDAG(ifBlock.getPredicateHops(), callVars);
        // the if body works on callVars, the else body on its own copy
        LocalVariableMap varsBefore = (LocalVariableMap) callVars.clone();
        LocalVariableMap varsElse = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : ifStmt.getIfBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            propagateStatisticsAcrossBlock(child, varsElse, fcallSizes, fnStack);
        }
        // reconcile both branches against the state before the if
        LocalVariableMap reconciled = Recompiler.reconcileUpdatedCallVarsIf(varsBefore, callVars, varsElse, ifBlock);
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(reconciled, sb);
    } else if (sb instanceof ForStatementBlock) {
        // incl parfor
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        // old stats into predicate (from, to, increment)
        propagateStatisticsAcrossPredicateDAG(forBlock.getFromHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getToHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getIncrementHops(), callVars);
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, forBlock);
        // first pass over the body, remembering the variables as they were before
        LocalVariableMap varsBefore = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : forStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        if (Recompiler.reconcileUpdatedCallVarsLoops(varsBefore, callVars, forBlock)) {
            // second pass over the body if required
            for (StatementBlock child : forStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
            }
        }
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, sb);
    } else {
        // generic (last-level)
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, sb);
        // old stats in, new stats out if updated
        ArrayList<Hop> dagRoots = sb.getHops();
        DMLProgram program = sb.getDMLProg();
        // replace scalar reads with literals
        Hop.resetVisitStatus(dagRoots);
        propagateScalarsAcrossDAG(dagRoots, callVars);
        // refresh stats across dag
        Hop.resetVisitStatus(dagRoots);
        propagateStatisticsAcrossDAG(dagRoots, callVars);
        // propagate stats into function calls
        Hop.resetVisitStatus(dagRoots);
        propagateStatisticsIntoFunctions(program, dagRoots, callVars, fcallSizes, fnStack);
    }
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) ArrayList(java.util.ArrayList) DMLProgram(org.apache.sysml.parser.DMLProgram) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 54 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

the class Recompiler method recompileProgramBlockInstructions.

/**
 * Regenerates lops and instructions for a single program block without a
 * full program block recompile (no stats update, no rewrites, no recursion).
 * The primary use case is recompilation after hop configuration changes,
 * which preserves statistics (e.g., propagated worst case stats from other
 * program blocks) and gives better performance when recompiling individual
 * program blocks.
 *
 * @param pb program block
 * @throws IOException if IOException occurs
 */
public static void recompileProgramBlockInstructions(ProgramBlock pb) throws IOException {
    if (pb instanceof WhileProgramBlock) {
        // while: only the predicate instructions are regenerated
        WhileProgramBlock whileProg = (WhileProgramBlock) pb;
        WhileStatementBlock whileStmtBlock = (WhileStatementBlock) pb.getStatementBlock();
        boolean hasPredicate = whileStmtBlock != null && whileStmtBlock.getPredicateHops() != null;
        if (hasPredicate) {
            whileProg.setPredicate(recompileHopsDagInstructions(whileStmtBlock.getPredicateHops()));
        }
    } else if (pb instanceof IfProgramBlock) {
        // if: only the predicate instructions are regenerated
        IfProgramBlock ifProg = (IfProgramBlock) pb;
        IfStatementBlock ifStmtBlock = (IfStatementBlock) pb.getStatementBlock();
        boolean hasPredicate = ifStmtBlock != null && ifStmtBlock.getPredicateHops() != null;
        if (hasPredicate) {
            ifProg.setPredicate(recompileHopsDagInstructions(ifStmtBlock.getPredicateHops()));
        }
    } else if (pb instanceof ForProgramBlock) {
        // for/parfor: from, to, and increment instructions are regenerated independently
        ForProgramBlock forProg = (ForProgramBlock) pb;
        ForStatementBlock forStmtBlock = (ForStatementBlock) pb.getStatementBlock();
        if (forStmtBlock != null) {
            if (forStmtBlock.getFromHops() != null) {
                forProg.setFromInstructions(recompileHopsDagInstructions(forStmtBlock.getFromHops()));
            }
            if (forStmtBlock.getToHops() != null) {
                forProg.setToInstructions(recompileHopsDagInstructions(forStmtBlock.getToHops()));
            }
            if (forStmtBlock.getIncrementHops() != null) {
                forProg.setIncrementInstructions(recompileHopsDagInstructions(forStmtBlock.getIncrementHops()));
            }
        }
    } else {
        // last-level program block: regenerate its instructions from the hop DAG roots
        StatementBlock stmtBlock = pb.getStatementBlock();
        boolean hasHops = stmtBlock != null && stmtBlock.getHops() != null;
        if (hasHops) {
            pb.setInstructions(recompileHopsDagInstructions(stmtBlock, stmtBlock.getHops()));
        }
    }
}
Also used : IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 55 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

the class ProgramRewriter method rRewriteStatementBlock.

/**
 * Rewrites a statement block: first recursively rewrites the bodies of
 * function, while, if, and for (incl parfor) blocks, then applies every
 * rule in {@code _sbRuleSet} to the block itself. Since a rule returns a
 * list of blocks, each following rule is applied to all blocks produced by
 * the previous one.
 *
 * @param sb        statement block to rewrite
 * @param status    program rewrite status; its parfor context flag is set
 *                  while a parfor body is rewritten and restored afterwards
 * @param splitDags if false, rules that report {@code createsSplitDag()} are skipped
 * @return list of statement blocks that results from applying all rules
 */
public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) {
    // recursive invocation on the bodies of control-flow blocks
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        funStmt.setBody(rRewriteStatementBlocks(funStmt.getBody(), status, splitDags));
    } else if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
        whileStmt.setBody(rRewriteStatementBlocks(whileStmt.getBody(), status, splitDags));
    } else if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
        ifStmt.setIfBody(rRewriteStatementBlocks(ifStmt.getIfBody(), status, splitDags));
        ifStmt.setElseBody(rRewriteStatementBlocks(ifStmt.getElseBody(), status, splitDags));
    } else if (sb instanceof ForStatementBlock) {
        // incl parfor
        // maintain parfor context information (e.g., for checkpointing)
        boolean outerParforContext = status.isInParforContext();
        if (sb instanceof ParForStatementBlock) {
            status.setInParforContext(true);
        }
        ForStatement forStmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
        forStmt.setBody(rRewriteStatementBlocks(forStmt.getBody(), status, splitDags));
        // restore the flag as it was before entering this loop
        status.setInParforContext(outerParforContext);
    }

    // apply rewrite rules to individual statement blocks, starting from sb
    ArrayList<StatementBlock> current = new ArrayList<>();
    current.add(sb);
    for (StatementBlockRewriteRule rule : _sbRuleSet) {
        // split-creating rules are only applied when splitting is allowed
        if (splitDags || !rule.createsSplitDag()) {
            ArrayList<StatementBlock> rewritten = new ArrayList<>();
            for (StatementBlock block : current) {
                rewritten.addAll(rule.rewriteStatementBlock(block, status));
            }
            // take over set of rewritten sbs
            current = rewritten;
        }
    }
    return current;
}
Also used : ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) ArrayList(java.util.ArrayList) WhileStatement(org.apache.sysml.parser.WhileStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Aggregations

WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)72 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)66 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)62 StatementBlock (org.apache.sysml.parser.StatementBlock)62 ForStatement (org.apache.sysml.parser.ForStatement)37 IfStatement (org.apache.sysml.parser.IfStatement)36 WhileStatement (org.apache.sysml.parser.WhileStatement)35 FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock)32 Hop (org.apache.sysml.hops.Hop)29 WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock)26 ArrayList (java.util.ArrayList)25 ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock)24 IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock)24 FunctionStatement (org.apache.sysml.parser.FunctionStatement)23 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)16 ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock)14 ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock)12 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)12 Instruction (org.apache.sysml.runtime.instructions.Instruction)10 ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement)9