Use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by Apache.
Class Connection, method prepareScript:
/**
 * Prepares (precompiles) a script, sets input parameter values, and registers input and output variables.
 *
 * <p>The script runs through a reduced compilation chain: parsing, live variable analysis and
 * parse tree validation, HOP construction and HOP rewrites, replacement of persistent reads/writes
 * for the registered inputs/outputs, LOP construction, runtime program generation, and a final
 * cleanup of the runtime program with respect to the registered outputs.
 *
 * @param script string representing the DML or PyDML script
 * @param args map of input parameters ($) and their values
 * @param inputs string array of input variables to register
 * @param outputs string array of output variables to register
 * @param parsePyDML {@code true} if PyDML, {@code false} if DML
 * @return PreparedScript object representing the precompiled script
 * @throws DMLException a ParseException raised during compilation is rethrown as-is;
 *         any other exception is wrapped into a DMLException
 */
public PreparedScript prepareScript(String script, Map<String, String> args, String[] inputs, String[] outputs, boolean parsePyDML) throws DMLException {
    ScriptType scriptType = parsePyDML ? ScriptType.PYDML : ScriptType.DML;
    // record the selected language in the global DMLScript setting before compiling
    DMLScript.SCRIPT_TYPE = scriptType;

    Program runtimeProgram;
    try {
        // parse the script text
        ParserWrapper parserWrapper = ParserFactory.createParser(scriptType);
        DMLProgram dmlProgram = parserWrapper.parse(null, script, args);

        // language-level validation
        DMLTranslator translator = new DMLTranslator(dmlProgram);
        translator.liveVariableAnalysis(dmlProgram);
        translator.validateParseTree(dmlProgram);

        // build HOP DAGs and apply the standard HOP rewrites
        translator.constructHops(dmlProgram);
        translator.rewriteHopsDAG(dmlProgram);

        // rewrite persistent reads/writes of the registered inputs/outputs
        new ProgramRewriter(new RewriteRemovePersistentReadWrite(inputs, outputs)).rewriteProgramHopDAGs(dmlProgram);

        // build LOPs and generate the runtime program
        translator.constructLops(dmlProgram);
        runtimeProgram = dmlProgram.getRuntimeProgram(_dmlconf);

        // final cleanup of the runtime program
        JMLCUtils.cleanupRuntimeProgram(runtimeProgram, outputs);
    } catch (ParseException parseEx) {
        // don't chain ParseException (for cleaner error output)
        throw parseEx;
    } catch (Exception otherEx) {
        throw new DMLException(otherEx);
    }

    // hand back the newly created precompiled script
    return new PreparedScript(runtimeProgram, inputs, outputs);
}
Use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by Apache.
Class ScriptExecutor, method rewritePersistentReadsAndWrites:
/**
 * Replace persistent reads and writes with transient reads and writes in
 * the symbol table.
 *
 * <p>Does nothing if the script has no symbol table. If the script's input or output
 * variables are {@code null}, they are treated as empty arrays.
 *
 * @throws MLContextException if the HOP DAG rewrite raises a LanguageException or HopsException
 */
protected void rewritePersistentReadsAndWrites() {
    LocalVariableMap symbolTable = script.getSymbolTable();
    if (symbolTable == null) {
        return;
    }
    String[] inputs = (script.getInputVariables() == null) ? new String[0] : script.getInputVariables().toArray(new String[0]);
    String[] outputs = (script.getOutputVariables() == null) ? new String[0] : script.getOutputVariables().toArray(new String[0]);
    // use the null-checked local instead of fetching the symbol table a second time
    RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs, symbolTable);
    ProgramRewriter programRewriter = new ProgramRewriter(rewrite);
    try {
        programRewriter.rewriteProgramHopDAGs(dmlProgram);
    } catch (LanguageException e) {
        throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e);
    } catch (HopsException e) {
        throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e);
    }
}
Use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by Apache.
Class OptimizerRuleBased, method rewriteInjectSparkLoopCheckpointing:
///////
//REWRITE inject spark loop checkpointing
///

/**
 * Applies the spark loop checkpoint injection rewrite (without context awareness) to the parfor
 * statement block mapped to the given opt node and to its body. If the rewrite injected new
 * checkpoints, the child blocks of the mapped parfor program block are regenerated from the
 * rewritten body.
 *
 * @param n opt node whose ID is used to look up the mapped parfor statement/program blocks
 * @throws DMLRuntimeException wrapping any exception raised while rewriting or regenerating
 */
protected void rewriteInjectSparkLoopCheckpointing(OptNode n) throws DMLRuntimeException {
    // mapped program objects of the root parfor: [0] statement block, [1] program block
    Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
    ParForStatementBlock parforSb = (ParForStatementBlock) mapped[0];
    ParForStatement parforStmt = (ParForStatement) parforSb.getStatement(0);
    ParForProgramBlock parforPb = (ParForProgramBlock) mapped[1];

    boolean injected;
    try {
        ProgramRewriter checkpointRewriter = new ProgramRewriter(new RewriteInjectSparkLoopCheckpointing(false));
        ProgramRewriteStatus status = new ProgramRewriteStatus();
        checkpointRewriter.rewriteStatementBlockHopDAGs(parforSb, status);
        parforStmt.setBody(checkpointRewriter.rewriteStatementBlocks(parforStmt.getBody(), status));

        // only regenerate the runtime blocks when the rewrite actually added checkpoints
        injected = status.getInjectedCheckpoints();
        if (injected) {
            parforPb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(parforPb.getProgram(), parforStmt.getBody()));
        }
    } catch (Exception cause) {
        throw new DMLRuntimeException(cause);
    }

    LOG.debug(getOptMode() + " OPT: rewrite 'inject spark loop checkpointing' - result=" + injected);
}
Use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by Apache.
Class OptimizationWrapper, method createProgramRewriterWithRuleSets:
/**
 * Builds a program rewriter with two rule sets. Constant folding is the only HOP rewrite, and
 * removal of unnecessary branches is the only statement block rewrite.
 *
 * @return a new ProgramRewriter configured with these rule sets
 */
private static ProgramRewriter createProgramRewriterWithRuleSets() {
    // HOP-level rules
    ArrayList<HopRewriteRule> hopRules = new ArrayList<HopRewriteRule>();
    hopRules.add(new RewriteConstantFolding());

    // statement-block-level rules
    ArrayList<StatementBlockRewriteRule> blockRules = new ArrayList<StatementBlockRewriteRule>();
    blockRules.add(new RewriteRemoveUnnecessaryBranches());

    return new ProgramRewriter(hopRules, blockRules);
}
Use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by Apache.
Class OptimizationWrapper, method optimize:
/**
 * Optimizes a parfor loop with the optimizer selected by {@code otype}.
 *
 * <p>When dynamic recompilation is enabled, the parfor body is first prepared via
 * {@link #recompileParForBodyBeforeOptimization}. Afterwards, the method:
 * <ul>
 * <li>creates an opt tree,</li>
 * <li>creates a cost estimator for the optimizer's cost model, and</li>
 * <li>runs the optimizer on the tree.</li>
 * </ul>
 * With {@code DMLScript.STATISTICS} enabled, the optimization count and time are recorded. With
 * {@code monitor} set, the optimizer type and plan counts are put into the statistic monitor.
 *
 * @param otype optimizer type passed to {@code createOptimizer}
 * @param ck passed to {@code OptTreeConverter.createOptTree}
 * @param cm passed to {@code OptTreeConverter.createOptTree}
 * @param sb parfor statement block to optimize
 * @param pb parfor program block to optimize
 * @param ec execution context whose variables are used for recompilation and cost estimation
 * @param monitor whether to record optimizer statistics for {@code pb} in the statistic monitor
 * @throws DMLRuntimeException if recompilation, opt tree creation or the plan correctness check
 *         fails, or if propagated from the optimizer
 */
// NOTE(review): "unused" is presumably suppressed for the CHECK_PLAN_CORRECTNESS branch below
// (likely a compile-time constant) — confirm
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
    Timing time = new Timing(true);

    //maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();

    //create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + cmtype + ")");

    //recompile parfor body (constant propagation, program rewrites, recompilation)
    if (ConfigurationManager.isDynamicRecompilation())
        recompileParForBodyBeforeOptimization(ck, cm, opt, sb, pb, ec);

    try {
        //create opt tree (before optimization)
        OptTree tree;
        try {
            tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
            //build the explain string only if it is actually logged
            if (LOG.isDebugEnabled())
                LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
        } catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        }

        //create cost estimator
        CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
        LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");

        //core optimize
        opt.optimize(sb, pb, tree, est, ec);
        if (LOG.isDebugEnabled())
            LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));

        //assert plan correctness
        if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
            try {
                OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
                LOG.debug("ParFOR Opt: Checked plan and program correctness.");
            } catch (Exception ex) {
                throw new DMLRuntimeException("Failed to check program correctness.", ex);
            }
        }

        long ltime = (long) time.stop();
        LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
        if (DMLScript.STATISTICS)
            Statistics.incrementParForOptimTime(ltime);
    } finally {
        //cleanup phase: the converter state is released via a static call; do it on the
        //failure path as well, so no stale plan mapping survives an exception
        OptTreeConverter.clear();
    }

    //monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}

/**
 * Prepares the parfor body for optimization under dynamic recompilation. The steps are:
 * <ol>
 * <li>Optionally log the input plan (debug only).</li>
 * <li>Replace reusable constant scalar variables in {@code sb}.</li>
 * <li>Apply the program rewrites from {@code createProgramRewriterWithRuleSets}. If branches were
 * removed, the child runtime blocks of {@code pb} are rebuilt.</li>
 * <li>Recompile the parfor body on a clone of the current variables.</li>
 * <li>Recompile the functions reported by inter-procedural analysis of the body.</li>
 * </ol>
 *
 * @param ck passed to {@code OptTreeConverter.createOptTree} (debug output only)
 * @param cm passed to {@code OptTreeConverter.createOptTree} (debug output only)
 * @param opt optimizer whose plan input type is used for the debug opt tree
 * @param sb parfor statement block to prepare
 * @param pb parfor program block to prepare
 * @param ec execution context providing the current variables
 * @throws DMLRuntimeException if any of these steps fails
 */
private static void recompileParForBodyBeforeOptimization(int ck, double cm, Optimizer opt, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec) throws DMLRuntimeException {
    ForStatement fs = (ForStatement) sb.getStatement(0);

    //debug output before recompilation
    if (LOG.isDebugEnabled()) {
        try {
            OptTree tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
            LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
        } catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        } finally {
            //the debug tree is discarded, so always release the converter state
            OptTreeConverter.clear();
        }
    }

    //separate propagation required because recompilation is in-place without literal replacement
    try {
        LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
        ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }

    //program rewrites (e.g., constant folding, branch removal) according to replaced literals
    try {
        ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
        ProgramRewriteStatus state = new ProgramRewriteStatus();
        rewriter.rewriteStatementBlockHopDAGs(sb, state);
        fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
        if (state.getRemovedBranches()) {
            LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
            pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
        }
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }

    //recompilation of parfor body and called functions (if safe)
    try {
        //core parfor body recompilation (based on symbol table entries)
        //* clone of variables in order to allow for statistics propagation across DAGs
        //(tid=0, because deep copies created after opt)
        LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
        Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);

        //inter-procedural optimization (based on previous recompilation)
        if (pb.hasFunctions()) {
            InterProceduralAnalysis ipa = new InterProceduralAnalysis();
            Set<String> fcand = ipa.analyzeSubProgram(sb);
            if (!fcand.isEmpty()) {
                //regenerate runtime program of modified functions
                for (String func : fcand) {
                    String[] funcparts = DMLProgram.splitFunctionKey(func);
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                    //reset recompilation flags according to recompileOnce, because this is only safe if the
                    //function is recompileOnce and hence recompiled for every execution (otherwise potential
                    //issues if the function is also called outside the parfor)
                    Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce());
                }
            }
        }
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }
}
Aggregations