Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.
The class ProgramRewriter, method rewriteStatementBlock.
/**
 * Rewrites a single statement block.
 *
 * <p>First recurses into the nested bodies of control-flow blocks (function,
 * while, if/else, for/parfor) and replaces each body with its rewritten
 * version. Afterwards every rule in {@code _sbRuleSet} is applied in order;
 * each rule sees the complete output of the previous rule, and a rule may
 * return a different number of blocks than it was given.
 *
 * @param sb     statement block to rewrite; nested bodies are replaced in place
 * @param status rewrite status shared across the traversal; its parfor-context
 *               flag is set to {@code true} while the body of a parfor block is
 *               rewritten and restored to its previous value afterwards
 * @return the list of blocks that results from applying all rules to {@code sb}
 * @throws HopsException declared by the nested rewrite and the rewrite rules
 */
private ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException {
    ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
    ret.add(sb);

    //recursive invocation
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status));
    } else if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock wsb = (WhileStatementBlock) sb;
        WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
        wstmt.setBody(rewriteStatementBlocks(wstmt.getBody(), status));
    } else if (sb instanceof IfStatementBlock) {
        IfStatementBlock isb = (IfStatementBlock) sb;
        IfStatement istmt = (IfStatement) isb.getStatement(0);
        istmt.setIfBody(rewriteStatementBlocks(istmt.getIfBody(), status));
        istmt.setElseBody(rewriteStatementBlocks(istmt.getElseBody(), status));
    } else if (sb instanceof ForStatementBlock) { //incl parfor
        //maintain parfor context information (e.g., for checkpointing)
        boolean prestatus = status.isInParforContext();
        if (sb instanceof ParForStatementBlock)
            status.setInParforContext(true);
        try {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status));
        } finally {
            //restore the caller's context even if the body rewrite fails, so a
            //shared status object never stays flagged as "inside parfor"
            status.setInParforContext(prestatus);
        }
    }

    //apply rewrite rules
    for (StatementBlockRewriteRule r : _sbRuleSet) {
        ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
        for (StatementBlock sbc : ret)
            tmp.addAll(r.rewriteStatementBlock(sbc, status));
        //take over set of rewritten sbs
        ret = tmp;
    }
    return ret;
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.
The class InterProceduralAnalysis, method getFunctionCandidatesForStatisticPropagation.
/////////////////////////////
// GET FUNCTION CANDIDATES
//////
/**
 * Recursively descends through the statement-block hierarchy rooted at
 * {@code sb}. Control-flow blocks (function, while, if/else, for/parfor) are
 * traversed via their nested bodies; for every last-level block each HOP DAG
 * root is handed to the hop-level overload of this method together with the
 * block's DML program.
 *
 * @param sb          statement block to inspect
 * @param fcandCounts passed through unchanged to the hop-level overload
 * @param fcandHops   passed through unchanged to the hop-level overload
 * @throws HopsException  declared by the hop-level overload
 * @throws ParseException declared by the hop-level overload
 */
private void getFunctionCandidatesForStatisticPropagation(StatementBlock sb, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) throws HopsException, ParseException {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funcStmt = (FunctionStatement) sb.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        for (StatementBlock child : whileStmt.getBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    //incl parfor
    if (sb instanceof ForStatementBlock) {
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        for (StatementBlock child : forStmt.getBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    //generic (last-level) block; the hop list is null for empty statement blocks
    ArrayList<Hop> dagRoots = sb.get_hops();
    if (dagRoots == null) {
        return;
    }
    for (Hop dagRoot : dagRoots) {
        getFunctionCandidatesForStatisticPropagation(sb.getDMLProg(), dagRoot, fcandCounts, fcandHops);
    }
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.
The class InterProceduralAnalysis, method rRemoveConstantBinaryOp.
/**
 * Recursively descends through the statement-block hierarchy rooted at
 * {@code sb}. If/else, while and for blocks are traversed via their nested
 * bodies (the else body only when it is non-null). For every other block that
 * has hops, the visit status of its hop roots is reset and the hop-level
 * overload of this method is applied to each root.
 *
 * @param sb    statement block to process
 * @param mOnes passed through unchanged to the hop-level overload
 * @throws HopsException declared by the hop-level overload
 */
private void rRemoveConstantBinaryOp(StatementBlock sb, HashMap<String, Hop> mOnes) throws HopsException {
    if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody()) {
            rRemoveConstantBinaryOp(child, mOnes);
        }
        if (ifStmt.getElseBody() != null) {
            for (StatementBlock child : ifStmt.getElseBody()) {
                rRemoveConstantBinaryOp(child, mOnes);
            }
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        for (StatementBlock child : whileStmt.getBody()) {
            rRemoveConstantBinaryOp(child, mOnes);
        }
        return;
    }

    if (sb instanceof ForStatementBlock) {
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        for (StatementBlock child : forStmt.getBody()) {
            rRemoveConstantBinaryOp(child, mOnes);
        }
        return;
    }

    //last-level block: nothing to do when it carries no hops
    ArrayList<Hop> dagRoots = sb.get_hops();
    if (dagRoots == null) {
        return;
    }
    Hop.resetVisitStatus(dagRoots);
    for (Hop dagRoot : dagRoots) {
        rRemoveConstantBinaryOp(dagRoot, mOnes);
    }
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.
The class ProgramRecompiler, method replaceConstantScalarVariables.
/**
 * Recursively walks the statement-block hierarchy rooted at {@code sb} and
 * applies literal replacement with the given variables.
 *
 * <ul>
 *   <li>if / while blocks: the predicate hops are passed to
 *       {@code replacePredicateLiterals}, then the nested bodies are processed;</li>
 *   <li>for / parfor blocks: the from, to and increment hops are passed to
 *       {@code replacePredicateLiterals}, then the body is processed;</li>
 *   <li>any other block: if it has hops, their visit status is reset and
 *       {@code Recompiler.rReplaceLiterals(root, vars, true)} is called for
 *       every hop root.</li>
 * </ul>
 *
 * @param sb   statement block to process
 * @param vars variable map handed to the literal-replacement helpers
 * @throws DMLRuntimeException declared by the literal-replacement helpers
 * @throws HopsException       declared by the literal-replacement helpers
 */
public static void replaceConstantScalarVariables(StatementBlock sb, LocalVariableMap vars) throws DMLRuntimeException, HopsException {
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        replacePredicateLiterals(ifBlock.getPredicateHops(), vars);
        for (StatementBlock child : ifStmt.getIfBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        replacePredicateLiterals(whileBlock.getPredicateHops(), vars);
        for (StatementBlock child : whileStmt.getBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        return;
    }

    //for or parfor
    if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        replacePredicateLiterals(forBlock.getFromHops(), vars);
        replacePredicateLiterals(forBlock.getToHops(), vars);
        replacePredicateLiterals(forBlock.getIncrementHops(), vars);
        for (StatementBlock child : forStmt.getBody()) {
            replaceConstantScalarVariables(child, vars);
        }
        return;
    }

    //last level block
    ArrayList<Hop> dagRoots = sb.get_hops();
    if (dagRoots == null) {
        return;
    }
    //replace constant literals
    Hop.resetVisitStatus(dagRoots);
    for (Hop dagRoot : dagRoots) {
        Recompiler.rReplaceLiterals(dagRoot, vars, true);
    }
}
Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.
The class OptimizationWrapper, method optimize.
/**
 * Runs the parfor optimizer for one parfor statement block / program block pair.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>create the optimizer for {@code otype} and read its cost model type;</li>
 *   <li>if dynamic recompilation is enabled: replace constant scalar variables
 *       by literals, run program rewrites on the parfor body (rebuilding the
 *       runtime child blocks if branches were removed), recompile the child
 *       program blocks against a clone of the current variables, and, if the
 *       program block has functions, run inter-procedural analysis and
 *       recompile the returned function candidates;</li>
 *   <li>create the opt tree and the cost estimator, then invoke the optimizer;</li>
 *   <li>optionally check plan correctness (only if {@code CHECK_PLAN_CORRECTNESS}
 *       is set and debug logging is enabled);</li>
 *   <li>record timing statistics, clear the {@code OptTreeConverter} state and,
 *       if requested, publish monitoring statistics.</li>
 * </ol>
 *
 * <p>Plan explanations are only built when debug logging is enabled.
 * NOTE(review): this assumes {@code OptTree.explain} has no side effects — confirm.
 *
 * @param otype   optimizer mode used to create the optimizer
 * @param ck      forwarded unchanged to {@code OptTreeConverter.createOptTree}
 *                (presumably the parallelism constraint — confirm against caller)
 * @param cm      forwarded unchanged to {@code OptTreeConverter.createOptTree}
 *                (presumably the memory constraint — confirm against caller)
 * @param sb      parfor statement block (HOP level)
 * @param pb      parfor program block (runtime level)
 * @param ec      execution context providing the current variables
 * @param monitor if {@code true}, optimizer statistics are put into {@code StatisticMonitor}
 * @throws DMLRuntimeException if any recompilation, opt-tree creation or plan check step
 *                             fails (the cause is wrapped), or if the optimizer itself fails
 */
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
    Timing time = new Timing(true);

    //maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();

    //create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + cmtype + ")");

    OptTree tree = null;

    //recompile parfor body
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);

        //debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }

        //separate propagation required because recompile in-place without literal replacement
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }

        //program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }

        //recompilation of parfor body and called functions (if safe)
        try {
            //core parfor body recompilation (based on symbol table entries)
            //* clone of variables in order to allow for statistics propagation across DAGs
            //(tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);

            //inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis();
                Set<String> fcand = ipa.analyzeSubProgram(sb);
                if (!fcand.isEmpty()) {
                    //regenerate runtime program of modified functions
                    for (String func : fcand) {
                        String[] funcparts = DMLProgram.splitFunctionKey(func);
                        FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                        //reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                        //because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                        Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce());
                    }
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    //create opt tree (before optimization)
    try {
        tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
        //only build the plan explanation when it will actually be logged
        if (LOG.isDebugEnabled())
            LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
    } catch (Exception ex) {
        throw new DMLRuntimeException("Unable to create opt tree.", ex);
    }

    //create cost estimator
    CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
    LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");

    //core optimize
    opt.optimize(sb, pb, tree, est, ec);
    //only build the plan explanation when it will actually be logged
    if (LOG.isDebugEnabled())
        LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));

    //assert plan correctness
    if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
        try {
            OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
            LOG.debug("ParFOR Opt: Checked plan and program correctness.");
        } catch (Exception ex) {
            throw new DMLRuntimeException("Failed to check program correctness.", ex);
        }
    }

    long ltime = (long) time.stop();
    LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimTime(ltime);

    //cleanup phase
    OptTreeConverter.clear();

    //monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Aggregations