use of org.apache.sysml.runtime.controlprogram.Program in project incubator-systemml by apache.
the class GDFEnumOptimizer method optimize.
/**
 * Optimizes the runtime program attached to the given graph and records
 * optimization statistics in the given summary.
 *
 * The method (1) optionally costs the current program as a baseline bound,
 * (2) enumerates plans per root node and keeps the min-cost plan of each,
 * (3) applies the chosen plan configurations, and (4) recompiles and
 * re-costs the program.
 *
 * @param gdfgraph graph whose runtime program and root nodes are optimized
 * @param summary receives costs, plan counters, mismatch counts and elapsed time
 * @return the same graph instance that was passed in
 * @throws DMLRuntimeException declared; raised by the called runtime components
 * @throws HopsException declared; raised by the called compiler components
 * @throws LopsException declared; raised by the called compiler components
 */
@Override
public GDFGraph optimize(GDFGraph gdfgraph, Summary summary) throws DMLRuntimeException, HopsException, LopsException {
    Timing timer = new Timing(true);

    Program rtprog = gdfgraph.getRuntimeProgram();
    ExecutionContext ec = ExecutionContextFactory.createContext(rtprog);
    ArrayList<GDFNode> rootNodes = gdfgraph.getGraphRootNodes();

    // Step 1: baseline costing; the bound stays at MAX_VALUE when pruning is off
    double initCosts = Double.MAX_VALUE;
    if (BRANCH_AND_BOUND_PRUNING) {
        double baseline = CostEstimationWrapper.getTimeEstimate(rtprog, ec);
        // relax the baseline by the configured relative threshold
        initCosts = baseline * (1 + BRANCH_AND_BOUND_REL_THRES);
    }

    // Step 2: plan enumeration per root, keeping the min-cost plan of each root
    ArrayList<Plan> bestRootPlans = new ArrayList<Plan>();
    for (GDFNode root : rootNodes) {
        PlanSet candidates = enumOpt(root, _memo, initCosts);
        bestRootPlans.add(candidates.getPlanWithMinCosts());
    }
    // mismatches observed during enumeration (read before the counter is reset)
    long enumPlanMismatch = getPlanMismatches();

    // Step 3: apply the selected plan configurations; the counter is reset
    // first so that it only reflects mismatches of this final pass
    HashMap<Long, Plan> configMemo = new HashMap<Long, Plan>();
    resetPlanMismatches();
    for (Plan rootPlan : bestRootPlans) {
        rSetRuntimePlanConfig(rootPlan, configMemo);
    }
    long finalPlanMismatch = getPlanMismatches();

    // Step 4: generate the final runtime plan and cost it with a fresh context
    Recompiler.recompileProgramBlockHierarchy(rtprog.getProgramBlocks(), new LocalVariableMap(), 0, false);
    ec = ExecutionContextFactory.createContext(rtprog);
    double optCosts = CostEstimationWrapper.getTimeEstimate(rtprog, ec);

    // Step 5: maintain optimization summary statistics
    summary.setCostsInitial(initCosts);
    summary.setCostsOptimal(optCosts);
    summary.setNumEnumPlans(_enumeratedPlans);
    summary.setNumPrunedInvalidPlans(_prunedInvalidPlans);
    summary.setNumPrunedSuboptPlans(_prunedSuboptimalPlans);
    summary.setNumCompiledPlans(_compiledPlans);
    summary.setNumCostedPlans(_costedPlans);
    summary.setNumEnumPlanMismatch(enumPlanMismatch);
    summary.setNumFinalPlanMismatch(finalPlanMismatch);
    summary.setTimeOptim(timer.stop());

    return gdfgraph;
}
use of org.apache.sysml.runtime.controlprogram.Program in project incubator-systemml by apache.
the class CostEstimator method rGetTimeEstimate.
/**
 * Recursively estimates the execution time of a program block.
 *
 * Control-flow blocks are handled as follows:
 * <ul>
 * <li>while: sum of child costs times {@code DEFAULT_NUMITER}</li>
 * <li>if: cost of the if body; when the else body contains at least one
 *     block, the average of if-body and else-body costs</li>
 * <li>for (includes parfor): sum of child costs times the iteration count
 *     obtained from {@code getNumIterations}</li>
 * <li>function (non-external): sum of child costs</li>
 * </ul>
 * Any other block (including external function blocks) is costed
 * instruction by instruction; only CP instructions and MR job instructions
 * contribute. A function call instruction additionally adds the estimate
 * of the called function.
 *
 * @param pb program block to estimate
 * @param stats variable statistics, passed to and maintained by the
 *        instruction-level statistics helpers
 * @param memoFunc keys of functions currently being estimated; a function
 *        whose key is already present is not entered again, which guards
 *        against recursive functions
 * @param recursive if true, child blocks of control-flow and function
 *        blocks are traversed; if false, those blocks contribute 0
 * @return estimated time as accumulated from the per-instruction estimators
 * @throws DMLRuntimeException if a called estimator or lookup fails
 */
private double rGetTimeEstimate(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) throws DMLRuntimeException {
    double ret = 0;
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock tmp = (WhileProgramBlock) pb;
        if (recursive) {
            for (ProgramBlock pb2 : tmp.getChildBlocks()) {
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
            }
        }
        ret *= DEFAULT_NUMITER;
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock tmp = (IfProgramBlock) pb;
        if (recursive) {
            for (ProgramBlock pb2 : tmp.getChildBlocksIfBody()) {
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
            }
            // Sum the else body separately; halving inside the loop would
            // divide the if-body costs by 2 once per else block.
            double elseCosts = 0;
            boolean hasElseBlocks = false;
            if (tmp.getChildBlocksElseBody() != null) {
                for (ProgramBlock pb2 : tmp.getChildBlocksElseBody()) {
                    elseCosts += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
                    hasElseBlocks = true;
                }
            }
            // weighted sum of both branches, applied exactly once
            if (hasElseBlocks) {
                ret = (ret + elseCosts) / 2;
            }
        }
    } else if (pb instanceof ForProgramBlock) { //includes ParFORProgramBlock
        ForProgramBlock tmp = (ForProgramBlock) pb;
        if (recursive) {
            for (ProgramBlock pb2 : tmp.getChildBlocks()) {
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
            }
        }
        ret *= getNumIterations(stats, tmp.getIterablePredicateVars());
    } else if (pb instanceof FunctionProgramBlock
            && !(pb instanceof ExternalFunctionProgramBlock)) { //external functions: see generic
        FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
        if (recursive) {
            for (ProgramBlock pb2 : tmp.getChildBlocks()) {
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
            }
        }
    } else {
        // generic block: cost each instruction individually
        ArrayList<Instruction> tmp = pb.getInstructions();
        for (Instruction inst : tmp) {
            if (inst instanceof CPInstruction) { //CP
                //obtain stats from createvar, cpvar, rmvar, rand
                maintainCPInstVariableStatistics((CPInstruction) inst, stats);
                //extract statistics (instruction-specific)
                Object[] o = extractCPInstStatistics(inst, stats);
                VarStats[] vs = (VarStats[]) o[0];
                String[] attr = (String[]) o[1];
                //call time estimation for inst
                ret += getCPInstTimeEstimate(inst, vs, attr);
                if (inst instanceof FunctionCallCPInstruction) { //functions
                    FunctionCallCPInstruction finst = (FunctionCallCPInstruction) inst;
                    String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
                    //awareness of recursive functions, missing program
                    if (!memoFunc.contains(fkey) && pb.getProgram() != null) {
                        if (LOG.isDebugEnabled())
                            LOG.debug("Begin Function " + fkey);
                        // mark the function as in-progress for the duration of its estimate
                        memoFunc.add(fkey);
                        Program prog = pb.getProgram();
                        FunctionProgramBlock fpb = prog.getFunctionProgramBlock(finst.getNamespace(), finst.getFunctionName());
                        ret += rGetTimeEstimate(fpb, stats, memoFunc, recursive);
                        memoFunc.remove(fkey);
                        if (LOG.isDebugEnabled())
                            LOG.debug("End Function " + fkey);
                    }
                }
            } else if (inst instanceof MRJobInstruction) { //MR
                //obtain stats for job
                maintainMRJobInstVariableStatistics(inst, stats);
                //extract input statistics
                Object[] o = extractMRJobInstStatistics(inst, stats);
                VarStats[] vs = (VarStats[]) o[0];
                if (LOG.isDebugEnabled())
                    LOG.debug("Begin MRJob type=" + ((MRJobInstruction) inst).getJobType());
                //call time estimation for complex MR inst
                ret += getMRJobInstTimeEstimate(inst, vs, null);
                if (LOG.isDebugEnabled())
                    LOG.debug("End MRJob");
                //cleanup stats for job
                cleanupMRJobVariableStatistics(inst, stats);
            }
        }
    }
    return ret;
}
use of org.apache.sysml.runtime.controlprogram.Program in project incubator-systemml by apache.
the class OptimizerRuleBased method removeUnnecessaryParFor.
/**
 * Walks the subtree below the given node and replaces every child parfor
 * node whose k equals 1 with a plain for node, swapping the mapped
 * program block accordingly.
 *
 * @param n node whose descendants are inspected (n itself is not replaced)
 * @return number of parfor nodes replaced in the subtree
 * @throws DMLRuntimeException if creating or replacing a program block fails
 */
protected int removeUnnecessaryParFor(OptNode n) throws DMLRuntimeException {
    // leaf nodes have no children to inspect
    if (n.isLeaf()) {
        return 0;
    }

    int replaced = 0;
    for (OptNode child : n.getChilds()) {
        boolean isParForWithK1 = child.getNodeType() == NodeType.PARFOR && child.getK() == 1;
        if (isParForWithK1) {
            // look up the statement block / program block pair mapped to this node
            Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(child.getID());
            ParForStatementBlock parforSb = (ParForStatementBlock) mapped[0];
            ParForProgramBlock parforPb = (ParForProgramBlock) mapped[1];

            // create for pb as replacement
            ForProgramBlock forPb = ProgramConverter.createShallowCopyForProgramBlock(parforPb, parforPb.getProgram());

            // replace parfor with for, and update object mapping
            OptTreeConverter.replaceProgramBlock(n, child, parforPb, forPb, false);

            // the new for block keeps the link to the original statement block
            forPb.setStatementBlock(parforSb);

            // update node
            child.setNodeType(NodeType.FOR);
            child.setK(1);
            replaced++;
        }
        // descend regardless of whether this child was replaced
        replaced += removeUnnecessaryParFor(child);
    }
    return replaced;
}
use of org.apache.sysml.runtime.controlprogram.Program in project incubator-systemml by apache.
the class ScriptExecutorUtils method executeRuntimeProgram.
/**
 * Runs the runtime program held by a script executor. Execution covers the
 * program blocks of the runtime program and may involve dynamic
 * recompilation.
 *
 * The runtime program, execution context and configuration are read from
 * the script executor and handed to the four-argument overload.
 *
 * @param se
 *            script executor providing program, context and configuration
 * @param statisticsMaxHeavyHitters
 *            maximum number of statistics to print
 * @throws DMLRuntimeException
 *             if exception occurs
 */
public static void executeRuntimeProgram(ScriptExecutor se, int statisticsMaxHeavyHitters) throws DMLRuntimeException {
    executeRuntimeProgram(
            se.getRuntimeProgram(),
            se.getExecutionContext(),
            se.getConfig(),
            statisticsMaxHeavyHitters);
}
use of org.apache.sysml.runtime.controlprogram.Program in project incubator-systemml by apache.
the class Connection method prepareScript.
/**
 * Precompiles a script into a runtime program, applying the given input
 * parameter values and registering the input and output variables.
 *
 * Note that this call also sets the static {@code DMLScript.SCRIPT_TYPE}.
 *
 * @param script string representing the DML or PyDML script
 * @param args map of input parameters ($) and their values
 * @param inputs string array of input variables to register
 * @param outputs string array of output variables to register
 * @param parsePyDML {@code true} if PyDML, {@code false} if DML
 * @return PreparedScript object representing the precompiled script
 * @throws DMLException if compilation fails; parse errors are rethrown
 *         unwrapped, all other exceptions are wrapped
 */
public PreparedScript prepareScript(String script, Map<String, String> args, String[] inputs, String[] outputs, boolean parsePyDML) throws DMLException {
    ScriptType scriptType = parsePyDML ? ScriptType.PYDML : ScriptType.DML;
    DMLScript.SCRIPT_TYPE = scriptType;

    // simplified compilation chain
    Program runtimeProgram = null;
    try {
        // parsing
        ParserWrapper scriptParser = ParserFactory.createParser(scriptType);
        DMLProgram dmlProgram = scriptParser.parse(null, script, args);

        // language validate
        DMLTranslator translator = new DMLTranslator(dmlProgram);
        translator.liveVariableAnalysis(dmlProgram);
        translator.validateParseTree(dmlProgram);

        // hop construct/rewrite
        translator.constructHops(dmlProgram);
        translator.rewriteHopsDAG(dmlProgram);

        // rewrite persistent reads/writes
        ProgramRewriter readWriteRewriter = new ProgramRewriter(new RewriteRemovePersistentReadWrite(inputs, outputs));
        readWriteRewriter.rewriteProgramHopDAGs(dmlProgram);

        // lop construct and runtime prog generation
        translator.constructLops(dmlProgram);
        runtimeProgram = dmlProgram.getRuntimeProgram(_dmlconf);

        // final cleanup runtime prog
        JMLCUtils.cleanupRuntimeProgram(runtimeProgram, outputs);
    } catch (ParseException parseError) {
        // don't chain ParseException (for cleaner error output)
        throw parseError;
    } catch (Exception cause) {
        throw new DMLException(cause);
    }

    // return newly created precompiled script
    return new PreparedScript(runtimeProgram, inputs, outputs);
}
Aggregations