Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by apache.
Class OptimizationWrapper, method optimize.
/**
 * Optimizes a single parfor program block with the optimizer selected by {@code otype}.
 *
 * <p>Steps, in order:
 * <ol>
 * <li>if dynamic recompilation is enabled: replace constant scalar variables by literals,
 *     apply program rewrites, and recompile the parfor body and called functions,</li>
 * <li>create the opt tree from the statement block / program block pair,</li>
 * <li>create the cost estimator matching the optimizer's cost model type,</li>
 * <li>run the optimizer,</li>
 * <li>optionally check plan correctness (only if {@code CHECK_PLAN_CORRECTNESS} is set
 *     and debug logging is enabled),</li>
 * <li>record timing and monitoring statistics and clear the {@code OptTreeConverter} state.</li>
 * </ol>
 *
 * @param otype   optimizer mode used to create the optimizer
 * @param ck      forwarded to {@code OptTreeConverter.createOptTree}
 *                (presumably the parallelism constraint - TODO confirm)
 * @param cm      forwarded to {@code OptTreeConverter.createOptTree}
 *                (presumably the memory constraint - TODO confirm)
 * @param sb      parfor statement block; its body may be rewritten in place
 * @param pb      parfor program block; its child blocks may be regenerated and recompiled
 * @param ec      execution context providing the current variables
 * @param monitor if {@code true}, optimizer statistics are reported to the {@code StatisticMonitor}
 * @throws DMLRuntimeException wrapping any exception raised during recompilation,
 *                             opt tree creation, or the plan correctness check
 */
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) {
    Timing time = new Timing(true);

    // maintain statistics
    if (DMLScript.STATISTICS) {
        Statistics.incrementParForOptimCount();
    }

    // create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + opt.getCostModelType() + ")");

    OptTree tree = null;

    // recompile parfor body
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);

        // debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                // discard converter state of this debug-only tree; the real tree is built below
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }

        // separate propagation required because recompile in-place without literal replacement
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }

        // program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rRewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
            if (state.getRemovedBranches()) {
                // the statement block structure changed, so the runtime program must be regenerated
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }

        // recompilation of parfor body and called functions (if safe)
        try {
            // core parfor body recompilation (based on symbol table entries)
            // * clone of variables in order to allow for statistics propagation across DAGs
            // (tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            ResetType reset = ConfigurationManager.isCodegenEnabled() ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);

            // inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
                Set<String> fcand = ipa.analyzeSubProgram();
                // regenerate runtime program of modified functions (no-op if there are none)
                for (String func : fcand) {
                    String[] funcparts = DMLProgram.splitFunctionKey(func);
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                    // reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                    // because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                    ResetType reset2 = fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
                    Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, reset2);
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    // create opt tree (before optimization)
    try {
        tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
        // only build the plan explanation string when it is actually logged
        if (LOG.isDebugEnabled()) {
            LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
        }
    } catch (Exception ex) {
        throw new DMLRuntimeException("Unable to create opt tree.", ex);
    }

    // create cost estimator
    CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
    LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");

    // core optimize
    opt.optimize(sb, pb, tree, est, ec);
    if (LOG.isDebugEnabled()) {
        LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
    }

    // assert plan correctness
    if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
        try {
            OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
            LOG.debug("ParFOR Opt: Checked plan and program correctness.");
        } catch (Exception ex) {
            throw new DMLRuntimeException("Failed to check program correctness.", ex);
        }
    }

    long ltime = (long) time.stop();
    LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
    if (DMLScript.STATISTICS) {
        Statistics.incrementParForOptimTime(ltime);
    }

    // cleanup phase
    OptTreeConverter.clear();

    // monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by apache.
Class ProgramRecompiler, method replaceConstantScalarVariables.
/**
 * Recursively walks a statement block hierarchy and replaces variables found in
 * {@code vars} by literals.
 *
 * <p>For control-flow blocks (if, while, for/parfor) the predicate HOPs are handled via
 * {@code replacePredicateLiterals} and all nested body blocks are visited recursively.
 * For any other (last-level) block, every root of the block's HOP DAG is passed to
 * {@code Recompiler.rReplaceLiterals}.
 *
 * @param sb   statement block to process (including all nested blocks)
 * @param vars variable map that supplies the values used for the replacement
 */
public static void replaceConstantScalarVariables(StatementBlock sb, LocalVariableMap vars) {
    // if-else: predicate, then both branches
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        replacePredicateLiterals(ifBlock.getPredicateHops(), vars);
        for (StatementBlock child : ifStmt.getIfBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        return;
    }

    // while: predicate, then body
    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        replacePredicateLiterals(whileBlock.getPredicateHops(), vars);
        for (StatementBlock child : whileStmt.getBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        return;
    }

    // for or parfor: from/to/increment expressions, then body
    if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        replacePredicateLiterals(forBlock.getFromHops(), vars);
        replacePredicateLiterals(forBlock.getToHops(), vars);
        replacePredicateLiterals(forBlock.getIncrementHops(), vars);
        for (StatementBlock child : forStmt.getBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        return;
    }

    // last level block: nothing to do if the block carries no HOP DAG
    ArrayList<Hop> dagRoots = sb.getHops();
    if (dagRoots == null) {
        return;
    }

    // replace constant literals; visit flags are reset first so every node is reachable
    Hop.resetVisitStatus(dagRoots);
    for (Hop root : dagRoots) {
        Recompiler.rReplaceLiterals(root, vars, true);
    }
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by apache.
Class Explain, method getHopDAG.
/**
 * Builds the graph description for all HOP DAGs contained in a statement block,
 * recursing into nested blocks of while, if/else, for/parfor and function blocks.
 *
 * <p>Each control-flow block is wrapped in a sub-graph via {@code addSubGraphHeader} /
 * {@code addSubGraphFooter}; exactly one footer is emitted per header. Generic
 * (last-level) blocks are only wrapped when they require recompilation, in which case
 * they are additionally marked with a filled light-grey style. The emitted attributes
 * look like Graphviz DOT - NOTE(review): confirm against the header/footer helpers.
 *
 * @param sb           statement block to explain
 * @param nodes        forwarded unchanged to the per-HOP overload of {@code getHopDAG}
 * @param lines        forwarded unchanged to the per-HOP overload of {@code getHopDAG}
 * @param withSubgraph forwarded to the sub-graph header/footer helpers and to recursive calls
 * @return a new builder holding the output for {@code sb} and everything nested in it
 */
private static StringBuilder getHopDAG(StatementBlock sb, StringBuilder nodes, ArrayList<Integer> lines, boolean withSubgraph) {
    StringBuilder builder = new StringBuilder();

    if (sb instanceof WhileStatementBlock) {
        addSubGraphHeader(builder, withSubgraph);
        WhileStatementBlock wsb = (WhileStatementBlock) sb;
        String label = "WHILE (lines " + wsb.getBeginLine() + "-" + wsb.getEndLine() + ")";
        if (!wsb.getUpdateInPlaceVars().isEmpty()) {
            label += " in-place=" + wsb.getUpdateInPlaceVars();
        }
        // TODO: Don't show predicate hops for now
        // builder.append(explainHop(wsb.getPredicateHops()));
        WhileStatement ws = (WhileStatement) sb.getStatement(0);
        for (StatementBlock current : ws.getBody()) {
            builder.append(getHopDAG(current, nodes, lines, withSubgraph));
        }
        addSubGraphFooter(builder, withSubgraph, label);
    } else if (sb instanceof IfStatementBlock) {
        addSubGraphHeader(builder, withSubgraph);
        IfStatementBlock ifsb = (IfStatementBlock) sb;
        String label = "IF (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")";
        // TODO: Don't show predicate hops for now
        // builder.append(explainHop(ifsb.getPredicateHops(), level+1));
        IfStatement ifs = (IfStatement) sb.getStatement(0);
        for (StatementBlock current : ifs.getIfBody()) {
            builder.append(getHopDAG(current, nodes, lines, withSubgraph));
        }
        // close the IF sub-graph exactly once, after all body blocks (previously the
        // footer was emitted once per body block, unbalancing header and footer)
        addSubGraphFooter(builder, withSubgraph, label);
        if (!ifs.getElseBody().isEmpty()) {
            addSubGraphHeader(builder, withSubgraph);
            label = "ELSE (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")";
            for (StatementBlock current : ifs.getElseBody()) {
                builder.append(getHopDAG(current, nodes, lines, withSubgraph));
            }
            addSubGraphFooter(builder, withSubgraph, label);
        }
    } else if (sb instanceof ForStatementBlock) {
        ForStatementBlock fsb = (ForStatementBlock) sb;
        addSubGraphHeader(builder, withSubgraph);
        // parfor blocks are also ForStatementBlocks and only differ in the label keyword
        String keyword = (sb instanceof ParForStatementBlock) ? "PARFOR" : "FOR";
        String label = keyword + " (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")";
        if (!fsb.getUpdateInPlaceVars().isEmpty()) {
            label += " in-place=" + fsb.getUpdateInPlaceVars();
        }
        // TODO: Don't show predicate hops for now
        // if (fsb.getFromHops() != null)
        // builder.append(explainHop(fsb.getFromHops(), level+1));
        // if (fsb.getToHops() != null)
        // builder.append(explainHop(fsb.getToHops(), level+1));
        // if (fsb.getIncrementHops() != null)
        // builder.append(explainHop(fsb.getIncrementHops(), level+1));
        ForStatement fs = (ForStatement) sb.getStatement(0);
        for (StatementBlock current : fs.getBody()) {
            builder.append(getHopDAG(current, nodes, lines, withSubgraph));
        }
        addSubGraphFooter(builder, withSubgraph, label);
    } else if (sb instanceof FunctionStatementBlock) {
        FunctionStatement fsb = (FunctionStatement) sb.getStatement(0);
        addSubGraphHeader(builder, withSubgraph);
        String label = "Function (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")";
        for (StatementBlock current : fsb.getBody()) {
            builder.append(getHopDAG(current, nodes, lines, withSubgraph));
        }
        addSubGraphFooter(builder, withSubgraph, label);
    } else {
        // For generic StatementBlock: only blocks marked for recompilation get their own sub-graph
        boolean recompile = sb.requiresRecompilation();
        if (recompile) {
            addSubGraphHeader(builder, withSubgraph);
        }
        ArrayList<Hop> hopsDAG = sb.getHops();
        if (hopsDAG != null && !hopsDAG.isEmpty()) {
            // reset visit flags before and after so the traversal leaves no marks behind
            Hop.resetVisitStatus(hopsDAG);
            for (Hop hop : hopsDAG) {
                builder.append(getHopDAG(hop, nodes, lines, withSubgraph));
            }
            Hop.resetVisitStatus(hopsDAG);
        }
        if (recompile) {
            builder.append("style=filled;\n");
            builder.append("color=lightgrey;\n");
            String label = "(lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ") [recompile=" + recompile + "]";
            addSubGraphFooter(builder, withSubgraph, label);
        }
    }
    return builder;
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by apache.
Class InterProceduralAnalysis, method propagateStatisticsAcrossBlock.
/////////////////////////////
// INTRA-PROCEDURE ANALYSIS
//////
/**
 * Propagates statistics and scalar values through a statement block and, recursively,
 * through every block nested inside it.
 *
 * <p>Function blocks simply recurse into their body. Loops (while, for/parfor) first
 * propagate into the loop header expressions, then into the body, and repeat the body
 * pass once more if reconciling the variables before and after the first pass reports
 * that this is required. If/else blocks propagate into both branches on separate copies
 * of the variable map and reconcile the results. Last-level blocks propagate scalars
 * and statistics across their HOP DAG and then into called functions.
 *
 * @param sb DML statement block to analyze.
 * @param fcand Function candidates, forwarded to the function-call propagation.
 * @param callVars Map of variables eligible for propagation; updated in place.
 * @param fcandSafeNNZ Function candidate safe non-zeros, forwarded to the function-call propagation.
 * @param unaryFcands Unary function candidates, forwarded to the function-call propagation.
 * @param fnStack Function stack to determine current scope.
 * @throws HopsException If a HopsException occurs.
 * @throws ParseException If a ParseException occurs.
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
    // ---- function: just descend into the function body
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funBlock = (FunctionStatementBlock) sb;
        FunctionStatement funStmt = (FunctionStatement) funBlock.getStatement(0);
        for (StatementBlock child : funStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        return;
    }

    // ---- while loop
    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);

        // stats as known before the loop go into the predicate
        propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);

        // scalars updated by the loop are no longer constant
        Recompiler.removeUpdatedScalars(callVars, whileBlock);

        // first pass over the body, keeping a snapshot for reconciliation
        LocalVariableMap snapshot = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : whileStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }

        // second pass (predicate and body) only if reconciliation asks for it
        boolean secondPass = Recompiler.reconcileUpdatedCallVarsLoops(snapshot, callVars, whileBlock);
        if (secondPass) {
            propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
            for (StatementBlock child : whileStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
            }
        }

        // drop updated constant scalars again after the loop
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    // ---- if / else
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);

        // stats as known before the branch go into the predicate
        propagateStatisticsAcrossPredicateDAG(ifBlock.getPredicateHops(), callVars);

        // the else branch works on its own copy; a second copy is kept for reconciliation
        LocalVariableMap snapshot = (LocalVariableMap) callVars.clone();
        LocalVariableMap elseVars = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : ifStmt.getIfBody()) {
            propagateStatisticsAcrossBlock(child, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            propagateStatisticsAcrossBlock(child, fcand, elseVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        LocalVariableMap reconciled = Recompiler.reconcileUpdatedCallVarsIf(snapshot, callVars, elseVars, ifBlock);

        // drop updated constant scalars from the reconciled map
        Recompiler.removeUpdatedScalars(reconciled, sb);
        return;
    }

    // ---- for loop (incl parfor)
    if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);

        // stats as known before the loop go into from/to/increment
        propagateStatisticsAcrossPredicateDAG(forBlock.getFromHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getToHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getIncrementHops(), callVars);

        // scalars updated by the loop are no longer constant
        Recompiler.removeUpdatedScalars(callVars, forBlock);

        // first pass over the body, keeping a snapshot for reconciliation
        LocalVariableMap snapshot = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : forStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }

        // second pass over the body only if reconciliation asks for it
        boolean secondPass = Recompiler.reconcileUpdatedCallVarsLoops(snapshot, callVars, forBlock);
        if (secondPass) {
            for (StatementBlock child : forStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
            }
        }

        // drop updated constant scalars again after the loop
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    // ---- generic (last-level) block
    // drop updated constant scalars before touching the DAG
    Recompiler.removeUpdatedScalars(callVars, sb);

    // old stats in, new stats out if updated
    ArrayList<Hop> dagRoots = sb.get_hops();
    DMLProgram program = sb.getDMLProg();

    // each traversal starts from clean visit flags
    // 1) replace scalar reads with literals
    Hop.resetVisitStatus(dagRoots);
    propagateScalarsAcrossDAG(dagRoots, callVars);

    // 2) refresh stats across dag
    Hop.resetVisitStatus(dagRoots);
    propagateStatisticsAcrossDAG(dagRoots, callVars);

    // 3) propagate stats into function calls
    Hop.resetVisitStatus(dagRoots);
    propagateStatisticsIntoFunctions(program, dagRoots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}
Aggregations