Usage of org.apache.sysml.parser.WhileStatementBlock in the Apache incubator-systemml project.
Class RewriteCompressedReblock, method rAnalyzeProgram.
/**
 * Recursively walks a statement block (and all nested blocks) and records in
 * {@code status} how the matrix names tracked in {@code status.compMtx} are used.
 * <ul>
 * <li>function blocks: every body block is analyzed;</li>
 * <li>while and for (incl. parfor) blocks: the body is analyzed first, then
 *     {@code status.usedInLoop} is set if the loop reads any tracked name;</li>
 * <li>if blocks: both branches are analyzed, then {@code status.condUpdate} is set
 *     if the block updates any tracked name;</li>
 * <li>last-level blocks with HOPs: each DAG root is analyzed via
 *     {@code rAnalyzeHopDag}, and names starting with {@code TMP_PREFIX} are removed
 *     from {@code status.compMtx} afterwards.</li>
 * </ul>
 *
 * @param sb     statement block to analyze
 * @param status probe state that is read and updated during the walk
 */
private static void rAnalyzeProgram(StatementBlock sb, ProbeStatus status) {
	if (sb instanceof FunctionStatementBlock) {
		FunctionStatement funcStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
		for (StatementBlock child : funcStmt.getBody())
			rAnalyzeProgram(child, status);
		return;
	}

	if (sb instanceof WhileStatementBlock) {
		WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
		WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
		for (StatementBlock child : whileStmt.getBody())
			rAnalyzeProgram(child, status);
		// checked only after the body walk: the recursion may modify status.compMtx
		if (whileBlock.variablesRead().containsAnyName(status.compMtx))
			status.usedInLoop = true;
		return;
	}

	if (sb instanceof IfStatementBlock) {
		IfStatementBlock ifBlock = (IfStatementBlock) sb;
		IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
		for (StatementBlock child : ifStmt.getIfBody())
			rAnalyzeProgram(child, status);
		for (StatementBlock child : ifStmt.getElseBody())
			rAnalyzeProgram(child, status);
		if (ifBlock.variablesUpdated().containsAnyName(status.compMtx))
			status.condUpdate = true;
		return;
	}

	if (sb instanceof ForStatementBlock) { // also covers parfor
		ForStatementBlock forBlock = (ForStatementBlock) sb;
		ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
		for (StatementBlock child : forStmt.getBody())
			rAnalyzeProgram(child, status);
		if (forBlock.variablesRead().containsAnyName(status.compMtx))
			status.usedInLoop = true;
		return;
	}

	// generic (last-level) block: nothing to do without a HOP DAG
	ArrayList<Hop> dagRoots = sb.getHops();
	if (dagRoots == null)
		return;

	Hop.resetVisitStatus(dagRoots);
	for (Hop root : dagRoots)
		rAnalyzeHopDag(root, status);
	// temporary variables are not tracked beyond the block that created them
	status.compMtx.removeIf(name -> name.startsWith(TMP_PREFIX));
	Hop.resetVisitStatus(dagRoots);
}
Usage of org.apache.sysml.parser.WhileStatementBlock in the Apache incubator-systemml project.
Class RewriteForLoopVectorization, method rewriteStatementBlock.
/**
 * Tries to replace a for (or parfor) loop whose body is a single last-level
 * statement block with an equivalent vectorized statement block.
 * <p>
 * The patterns are applied in sequence (scalar aggregate, element-wise binary,
 * element-wise unary, indexed copy). Each helper receives the current block and
 * returns either it or a replacement. Any other block is returned unchanged.
 *
 * @param sb    statement block to rewrite
 * @param state program rewrite status (unused by this rewrite)
 * @return a single-element list holding the original or the vectorized block
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
	if (!(sb instanceof ForStatementBlock))
		return Arrays.asList(sb);

	ForStatementBlock forBlock = (ForStatementBlock) sb;
	ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
	Hop from = forBlock.getFromHops();
	Hop to = forBlock.getToHops();
	Hop incr = forBlock.getIncrementHops();
	String iterVar = forBlock.getIterPredicate().getIterVar().getName();

	List<StatementBlock> body = forStmt.getBody();
	boolean singleChild = body != null && body.size() == 1;
	if (singleChild) {
		StatementBlock child = body.get(0);
		boolean lastLevel = !(child instanceof WhileStatementBlock
			|| child instanceof IfStatementBlock
			|| child instanceof ForStatementBlock);
		if (lastLevel) {
			// AUTO VECTORIZATION PATTERNS
			// Note: unnecessary row or column indexing is later removed via hop rewrites

			// e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2])
			sb = vectorizeScalarAggregate(sb, child, from, to, incr, iterVar);
			// e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
			sb = vectorizeElementwiseBinary(sb, child, from, to, incr, iterVar);
			// e.g., for(i in a:b){X[i,2] = abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1]);
			sb = vectorizeElementwiseUnary(sb, child, from, to, incr, iterVar);
			// e.g., for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b];
			sb = vectorizeIndexedCopy(sb, child, from, to, incr, iterVar);
		}
	}

	// If no pattern applied, sb is still the original for loop; otherwise it is the
	// block returned by the vectorize* rewrites, which includes the equivalent
	// vectorized operations.
	return Arrays.asList(sb);
}
Usage of org.apache.sysml.parser.WhileStatementBlock in the Apache incubator-systemml project.
Class RewriteMarkLoopVariablesUpdateInPlace, method rewriteStatementBlock.
/**
 * Marks the matrix variables of a while or for (incl. parfor) loop that are
 * candidates for update-in-place.
 * <p>
 * A variable is a candidate if it is updated in the loop, is a matrix, is live
 * after the loop, and {@code rIsApplicableForUpdateInPlace} accepts it for the
 * loop body. The candidate list is stored via {@code setUpdateInPlaceVars}. The
 * rewrite is skipped entirely when the runtime platform is HADOOP or SPARK.
 *
 * @param sb     statement block to inspect
 * @param status program rewrite status (unused by this rewrite)
 * @return a single-element list holding {@code sb}
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) {
	if (DMLScript.rtplatform == RUNTIME_PLATFORM.HADOOP || DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK) {
		// nothing to do here, return original statement block
		return Arrays.asList(sb);
	}
	// incl parfor
	if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
		// The loop body is the same for every variable, so resolve it once
		// instead of re-dispatching on the block type inside the loop.
		ArrayList<StatementBlock> loopBody = (sb instanceof WhileStatementBlock)
			? ((WhileStatement) sb.getStatement(0)).getBody()
			: ((ForStatement) sb.getStatement(0)).getBody();

		ArrayList<String> candidates = new ArrayList<>();
		VariableSet updated = sb.variablesUpdated();
		VariableSet liveout = sb.liveOut();
		for (String varname : updated.getVariableNames()) {
			// matrices only; requiring live-out excludes loop-local vars
			if (updated.getVariable(varname).getDataType() == DataType.MATRIX
				&& liveout.containsVariable(varname)
				&& rIsApplicableForUpdateInPlace(loopBody, varname)) {
				candidates.add(varname);
			}
		}
		sb.setUpdateInPlaceVars(candidates);
	}
	// return modified statement block
	return Arrays.asList(sb);
}
Usage of org.apache.sysml.parser.WhileStatementBlock in the Apache incubator-systemml project.
Class InterProceduralAnalysis, method propagateStatisticsAcrossBlock.
// ///////////////////////////
// INTRA-PROCEDURE ANALYSIS
// ///////////////////////////
/**
 * Propagates statistics and constant scalars from {@code callVars} through one
 * statement block. Recurses into the bodies of function, while, if and for
 * (incl. parfor) blocks. For last-level blocks it replaces scalar reads with
 * literals, refreshes statistics across the HOP DAG, and pushes statistics into
 * called functions.
 *
 * @param sb         statement block to process
 * @param callVars   variables known before {@code sb}; handed to the callees, which
 *                   presumably update it in place to reflect {@code sb} (confirm)
 * @param fcallSizes function call size info, forwarded to recursive calls and to
 *                   {@code propagateStatisticsIntoFunctions}
 * @param fnStack    forwarded unchanged to recursive calls and to
 *                   {@code propagateStatisticsIntoFunctions}; presumably the stack of
 *                   functions under analysis (recursion guard); TODO confirm
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) {
if (sb instanceof FunctionStatementBlock) {
// function: just walk the body blocks in order
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
} else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
// old (incoming) stats into predicate
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
// remove constant scalars updated in the loop before entering the body
// (presumably because their values may differ across iterations)
Recompiler.removeUpdatedScalars(callVars, wsb);
// snapshot pre-body state so reconcileUpdatedCallVarsLoops can compare it with the
// post-body state, then propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb)) {
// second pass if required: predicate and body again, since the while predicate
// is re-evaluated every iteration
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
}
// remove updated constant scalars (sb is the same object as wsb)
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
// old (incoming) stats into predicate
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
// the if-branch works on callVars directly and the else-branch on a copy, both taken
// before either branch runs; oldCallVars keeps the pre-branch state for reconciliation
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : istmt.getIfBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
for (StatementBlock sbi : istmt.getElseBody()) propagateStatisticsAcrossBlock(sbi, callVarsElse, fcallSizes, fnStack);
// NOTE(review): this only rebinds the local parameter. If reconcileUpdatedCallVarsIf
// returns a map other than callVars, the merged result is not visible to the caller;
// confirm that it updates callVars in place and returns it.
callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (// incl parfor
sb instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
// old (incoming) stats into the from/to/increment predicates
propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, fsb);
// check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
// second pass over the body only; unlike the while case, the from/to/increment
// predicates are not re-propagated (presumably because they are evaluated once)
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb))
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else // generic (last-level): operate directly on the block's HOP DAG roots
{
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
// old stats in, new stats out if updated
ArrayList<Hop> roots = sb.getHops();
DMLProgram prog = sb.getDMLProg();
// replace scalar reads with literals (visit status reset before every DAG traversal)
Hop.resetVisitStatus(roots);
propagateScalarsAcrossDAG(roots, callVars);
// refresh stats across dag
Hop.resetVisitStatus(roots);
propagateStatisticsAcrossDAG(roots, callVars);
// propagate stats into function calls
Hop.resetVisitStatus(roots);
propagateStatisticsIntoFunctions(prog, roots, callVars, fcallSizes, fnStack);
}
}
Usage of org.apache.sysml.parser.WhileStatementBlock in the Apache incubator-systemml project.
Class Recompiler, method recompileProgramBlockInstructions.
/**
 * Regenerates only the lops and instructions of a single program block from its
 * existing HOPs.
 * <p>
 * Unlike a full program block recompile, this performs no statistics update, no
 * rewrites and no recursion into child blocks. The primary use case is recompilation
 * after hop configuration changes. It preserves statistics (e.g., worst-case stats
 * propagated from other program blocks) and performs better when recompiling
 * individual program blocks.
 * <p>
 * For while and if blocks only the predicate instructions are regenerated. For
 * for/parfor blocks the from, to and increment instructions are regenerated, and for
 * all other blocks the block's own instructions. Parts whose statement block or HOPs
 * are {@code null} are left untouched.
 *
 * @param pb program block
 * @throws IOException propagated from {@code recompileHopsDagInstructions}
 */
public static void recompileProgramBlockInstructions(ProgramBlock pb) throws IOException {
	StatementBlock sb = pb.getStatementBlock();

	if (pb instanceof WhileProgramBlock) {
		// while predicate instructions
		WhileStatementBlock whileSb = (WhileStatementBlock) sb;
		if (whileSb != null && whileSb.getPredicateHops() != null)
			((WhileProgramBlock) pb).setPredicate(recompileHopsDagInstructions(whileSb.getPredicateHops()));
		return;
	}

	if (pb instanceof IfProgramBlock) {
		// if predicate instructions
		IfStatementBlock ifSb = (IfStatementBlock) sb;
		if (ifSb != null && ifSb.getPredicateHops() != null)
			((IfProgramBlock) pb).setPredicate(recompileHopsDagInstructions(ifSb.getPredicateHops()));
		return;
	}

	if (pb instanceof ForProgramBlock) {
		// for/parfor predicate instructions (from, to, increment)
		ForStatementBlock forSb = (ForStatementBlock) sb;
		if (forSb == null)
			return;
		ForProgramBlock forPb = (ForProgramBlock) pb;
		if (forSb.getFromHops() != null)
			forPb.setFromInstructions(recompileHopsDagInstructions(forSb.getFromHops()));
		if (forSb.getToHops() != null)
			forPb.setToInstructions(recompileHopsDagInstructions(forSb.getToHops()));
		if (forSb.getIncrementHops() != null)
			forPb.setIncrementInstructions(recompileHopsDagInstructions(forSb.getIncrementHops()));
		return;
	}

	// last-level program block instructions
	if (sb != null && sb.getHops() != null)
		pb.setInstructions(recompileHopsDagInstructions(sb, sb.getHops()));
}
Aggregations