use of org.apache.sysml.runtime.controlprogram.FunctionProgramBlock in project incubator-systemml by apache.
the class OptTreePlanChecker method checkProgramCorrectness.
/**
 * Walks a runtime program block and its corresponding statement block in lock-step
 * and checks each pair of HOP DAG and generated instructions via
 * {@code checkHopDagCorrectness}.
 *
 * <p>Function, while, if and for (incl. parfor) blocks are descended recursively;
 * the i-th child program block is always paired with the i-th statement block of the
 * matching body. Any other combination of block types is treated as a basic block
 * whose HOPs are checked against its instructions.
 *
 * @param pb      runtime program block to check
 * @param sb      statement block that {@code pb} was generated from
 * @param fnStack set of function names, passed through unchanged to all nested checks
 */
public static void checkProgramCorrectness(ProgramBlock pb, StatementBlock sb, Set<String> fnStack) {
    final Program rtProg = pb.getProgram();
    final DMLProgram dmlProg = sb.getDMLProg();

    // function: only the body is visited
    if (pb instanceof FunctionProgramBlock && sb instanceof FunctionStatementBlock) {
        FunctionProgramBlock funcPb = (FunctionProgramBlock) pb;
        FunctionStatementBlock funcSb = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcSb.getStatement(0);
        int pos = 0;
        for (ProgramBlock child : funcPb.getChildBlocks()) {
            checkProgramCorrectness(child, funcStmt.getBody().get(pos++), fnStack);
        }
        // NOTE: the checkLinksProgramStatementBlock call is disabled for function blocks
        return;
    }

    // while: predicate first, then the body, finally the block links
    if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
        WhileProgramBlock whilePb = (WhileProgramBlock) pb;
        WhileStatementBlock whileSb = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileSb.getStatement(0);
        checkHopDagCorrectness(rtProg, dmlProg, whileSb.getPredicateHops(), whilePb.getPredicate(), fnStack);
        int pos = 0;
        for (ProgramBlock child : whilePb.getChildBlocks()) {
            checkProgramCorrectness(child, whileStmt.getBody().get(pos++), fnStack);
        }
        checkLinksProgramStatementBlock(whilePb, whileSb);
        return;
    }

    // if: predicate, if-body, else-body, block links
    if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
        IfProgramBlock ifPb = (IfProgramBlock) pb;
        IfStatementBlock ifSb = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifSb.getStatement(0);
        checkHopDagCorrectness(rtProg, dmlProg, ifSb.getPredicateHops(), ifPb.getPredicate(), fnStack);
        int thenPos = 0;
        for (ProgramBlock child : ifPb.getChildBlocksIfBody()) {
            checkProgramCorrectness(child, ifStmt.getIfBody().get(thenPos++), fnStack);
        }
        int elsePos = 0;
        for (ProgramBlock child : ifPb.getChildBlocksElseBody()) {
            checkProgramCorrectness(child, ifStmt.getElseBody().get(elsePos++), fnStack);
        }
        checkLinksProgramStatementBlock(ifPb, ifSb);
        return;
    }

    // for (incl parfor): from/to/increment, then the body, finally the block links
    if (pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) {
        ForProgramBlock forPb = (ForProgramBlock) pb;
        ForStatementBlock forSb = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        checkHopDagCorrectness(rtProg, dmlProg, forSb.getFromHops(), forPb.getFromInstructions(), fnStack);
        checkHopDagCorrectness(rtProg, dmlProg, forSb.getToHops(), forPb.getToInstructions(), fnStack);
        checkHopDagCorrectness(rtProg, dmlProg, forSb.getIncrementHops(), forPb.getIncrementInstructions(), fnStack);
        int pos = 0;
        for (ProgramBlock child : forPb.getChildBlocks()) {
            checkProgramCorrectness(child, forStmt.getBody().get(pos++), fnStack);
        }
        checkLinksProgramStatementBlock(forPb, forSb);
        return;
    }

    // everything else: check the block's HOPs against its instructions
    // NOTE: the checkLinksProgramStatementBlock call is disabled for these blocks as well
    checkHopDagCorrectness(rtProg, dmlProg, sb.getHops(), pb.getInstructions(), fnStack);
}
use of org.apache.sysml.runtime.controlprogram.FunctionProgramBlock in project incubator-systemml by apache.
the class OptimizationWrapper method optimize.
/**
 * Runs the parfor optimizer for one parfor statement/program block pair.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>create the optimizer for {@code otype};</li>
 *   <li>if dynamic recompilation is enabled: replace constant scalar variables,
 *       apply program rewrites (rebuilding the runtime program of the body when
 *       branches were removed), recompile the parfor body, and recompile functions
 *       reported by inter-procedural analysis;</li>
 *   <li>create the opt tree and the cost estimator and invoke the optimizer;</li>
 *   <li>optionally check plan/program correctness (only if
 *       {@code CHECK_PLAN_CORRECTNESS} is set and debug logging is enabled);</li>
 *   <li>record timing and, if requested, monitoring statistics, and clear the
 *       {@code OptTreeConverter} state.</li>
 * </ol>
 *
 * <p>All plan dumps ({@code tree.explain(false)}) are only built when debug logging
 * is enabled, so no explain string is constructed just to be discarded by the logger.
 *
 * @param otype   optimizer mode, used to select the optimizer
 * @param ck      passed to {@code OptTreeConverter.createOptTree}
 *                (presumably the parallelism constraint; confirm against callers)
 * @param cm      passed to {@code OptTreeConverter.createOptTree}
 *                (presumably the memory constraint; confirm against callers)
 * @param sb      parfor statement block
 * @param pb      parfor program block
 * @param ec      execution context providing the current variables
 * @param monitor if true, optimizer statistics are reported to {@code StatisticMonitor}
 * @throws DMLRuntimeException if opt tree creation, constant propagation, rewrites,
 *         recompilation or the correctness check fail (the cause is preserved)
 */
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) {
    Timing time = new Timing(true);
    // maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();
    // create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + cmtype + ")");
    OptTree tree = null;
    // recompile parfor body
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);
        // debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                // drop the converter state of this debug-only tree; the real tree is built below
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }
        // separate propagation required (because recompile in-place without literal replacement)
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        // program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rRewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        // recompilation of parfor body and called functions (if safe)
        try {
            // core parfor body recompilation (based on symbol table entries)
            // * clone of variables in order to allow for statistics propagation across DAGs
            // (tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            ResetType reset = ConfigurationManager.isCodegenEnabled() ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);
            // inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
                Set<String> fcand = ipa.analyzeSubProgram();
                if (!fcand.isEmpty()) {
                    // regenerate runtime program of modified functions
                    for (String func : fcand) {
                        String[] funcparts = DMLProgram.splitFunctionKey(func);
                        FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                        // reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                        // because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                        ResetType reset2 = fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
                        Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, reset2);
                    }
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }
    // create opt tree (before optimization)
    try {
        tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
        // build the plan dump only if it is actually going to be logged
        if (LOG.isDebugEnabled()) {
            LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
        }
    } catch (Exception ex) {
        throw new DMLRuntimeException("Unable to create opt tree.", ex);
    }
    // create cost estimator
    CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
    LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
    // core optimize
    opt.optimize(sb, pb, tree, est, ec);
    if (LOG.isDebugEnabled()) {
        LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
    }
    // assert plan correctness
    if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
        try {
            OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
            LOG.debug("ParFOR Opt: Checked plan and program correctness.");
        } catch (Exception ex) {
            throw new DMLRuntimeException("Failed to check program correctness.", ex);
        }
    }
    long ltime = (long) time.stop();
    LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimTime(ltime);
    // cleanup phase
    OptTreeConverter.clear();
    // monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
use of org.apache.sysml.runtime.controlprogram.FunctionProgramBlock in project incubator-systemml by apache.
the class OptimizerRuleBased method rFindAndUnfoldRecursiveFunction.
/**
 * Searches the opt tree rooted at {@code n} for a recursive function call node and
 * unfolds it once.
 *
 * <p>For a recursive {@code FUNCCALL} node accepted by {@code rContainsNode(n, parfor)}:
 * the function program block is deep-copied and registered (runtime program and DML
 * program) under the name prefixed with {@code FUNCTION_UNFOLD_NAMEPREFIX}; function
 * names in the old subtree are redirected to the copy; a new opt node for the copy
 * replaces {@code n} below its parent; and {@code recPBs} is updated by removing the
 * parfor blocks of the old subtree and adding those of the new one. The search does not
 * descend below a recursive function call node.
 *
 * @param n      current opt node
 * @param parfor parfor program block used for the containment test and to reach the programs
 * @param recPBs set of parfor program blocks, updated in place when a function is unfolded
 * @param vars   variable map handed to {@code OptTreeConverter.rCreateAbstractOptNode}
 */
protected void rFindAndUnfoldRecursiveFunction(OptNode n, ParForProgramBlock parfor, HashSet<ParForProgramBlock> recPBs, LocalVariableMap vars) {
    // not a recursive function call: keep searching in the children
    if (n.getNodeType() != NodeType.FUNCCALL || !n.isRecursive()) {
        if (!n.isLeaf()) {
            for (OptNode child : n.getChilds()) {
                rFindAndUnfoldRecursiveFunction(child, parfor, recPBs, vars);
            }
        }
        return;
    }

    // recursive function call found: unfold only if the containment test succeeds
    if (!rContainsNode(n, parfor)) {
        return;
    }

    final String oldKey = n.getParam(ParamType.OPSTRING);
    final String[] keyParts = oldKey.split(Program.KEY_DELIM);
    final String namespace = keyParts[0];
    final String oldName = keyParts[1];
    final String unfoldedName = FUNCTION_UNFOLD_NAMEPREFIX + oldName;

    // unfold function: register a deep copy under the new name in both programs
    final FunctionOp funcHop = (FunctionOp) OptTreeConverter.getAbstractPlanMapping().getMappedHop(n.getID());
    final Program rtProg = parfor.getProgram();
    final DMLProgram dmlProg = parfor.getStatementBlock().getDMLProg();
    final FunctionProgramBlock original = rtProg.getFunctionProgramBlock(namespace, oldName);
    final FunctionProgramBlock unfolded = ProgramConverter.createDeepCopyFunctionProgramBlock(original, new HashSet<String>(), new HashSet<String>());
    rtProg.addFunctionProgramBlock(namespace, unfoldedName, unfolded);
    dmlProg.addFunctionStatementBlock(namespace, unfoldedName, (FunctionStatementBlock) unfolded.getStatementBlock());

    // replace function names in old subtree (link to new function)
    rReplaceFunctionNames(n, oldName, unfoldedName);

    // recreate sub opttree: new node for the unfolded function takes the place of n
    final String unfoldedKey = namespace + Program.KEY_DELIM + unfoldedName;
    final OptNode replacement = new OptNode(NodeType.FUNCCALL);
    OptTreeConverter.getAbstractPlanMapping().putHopMapping(funcHop, replacement);
    replacement.setExecType(ExecType.CP);
    replacement.addParam(ParamType.OPSTRING, unfoldedKey);
    final long parentId = OptTreeConverter.getAbstractPlanMapping().getMappedParentID(n.getID());
    OptTreeConverter.getAbstractPlanMapping().getOptNode(parentId).exchangeChild(n, replacement);

    final HashSet<String> visited = new HashSet<>();
    // required if functionop not shared (because not replaced yet)
    visited.add(oldKey);
    // required if functionop shared (indirectly replaced)
    visited.add(unfoldedKey);
    for (int idx = 0; idx < unfolded.getChildBlocks().size(); idx++) {
        final ProgramBlock bodyPb = unfolded.getChildBlocks().get(idx);
        final StatementBlock bodySb = bodyPb.getStatementBlock();
        replacement.addChild(OptTreeConverter.rCreateAbstractOptNode(bodySb, bodyPb, vars, false, visited));
    }

    // compute delta for recPB set (use for removing parfor)
    recPBs.removeAll(rGetAllParForPBs(n, new HashSet<ParForProgramBlock>()));
    recPBs.addAll(rGetAllParForPBs(replacement, new HashSet<ParForProgramBlock>()));

    // replace function names in new subtree (recursive link to new function)
    rReplaceFunctionNames(replacement, oldName, unfoldedName);
}
use of org.apache.sysml.runtime.controlprogram.FunctionProgramBlock in project incubator-systemml by apache.
the class FunctionCallCPInstruction method processInstruction.
/**
 * Executes this function call instruction.
 *
 * <p>The function program block is looked up in the program of {@code ec}, the bound
 * inputs are copied into a fresh variable map under the names of the formal
 * parameters (scalars with a differing value type are converted to the formal type),
 * and the function is executed in a newly created execution context. Afterwards,
 * cacheable variables of the function context that are not declared outputs are
 * cleaned up, and each declared output is bound to the corresponding output variable
 * name in {@code ec}, cleaning up cacheable data previously bound to that name.
 *
 * @param ec execution context of the caller
 * @throws DMLRuntimeException if fewer inputs are bound than the function declares,
 *         a non-literal input variable does not exist, the function execution fails,
 *         or a declared output was not assigned
 */
@Override
public void processInstruction(ExecutionContext ec) {
    if (LOG.isTraceEnabled()) {
        LOG.trace("Executing instruction : " + toString());
    }

    // get the function program block (stored in the Program object)
    final FunctionProgramBlock callee = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);

    // sanity check number of function parameters
    final int numFormals = callee.getInputParams().size();
    if (numFormals > _boundInputs.length) {
        throw new DMLRuntimeException("Number of bound input parameters does not match the function signature " + "(" + _boundInputs.length + ", but " + numFormals + " expected)");
    }

    // bind the actual arguments to the formal parameter names; this map becomes
    // the symbol table of the function invocation
    final LocalVariableMap calleeVars = new LocalVariableMap();
    int argPos = 0;
    for (DataIdentifier formal : callee.getInputParams()) {
        final CPOperand actual = _boundInputs[argPos++];
        // error handling non-existing variables
        if (!(actual.isLiteral() || ec.containsVariable(actual.getName()))) {
            throw new DMLRuntimeException("Input variable '" + actual.getName() + "' not existing on call of " + DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line " + getLineNum() + ").");
        }
        // get input matrix/frame/scalar
        Data arg = ec.getVariable(actual);
        // graceful value type conversion for scalar inputs with wrong type
        if (arg.getDataType() == DataType.SCALAR && arg.getValueType() != formal.getValueType()) {
            arg = ScalarObjectFactory.createScalarObject(formal.getValueType(), (ScalarObject) arg);
        }
        calleeVars.put(formal.getName(), arg);
    }

    // pin the input variables so that they do not get deleted from the caller's
    // symbol table at the end of the function execution
    final boolean[] pinned = ec.pinVariables(_boundInputNames);

    // new execution context for the invocation, populated with the bound arguments
    final ExecutionContext calleeEc = ExecutionContextFactory.createContext(false, ec.getProgram());
    if (DMLScript.USE_ACCELERATOR) {
        calleeEc.setGPUContexts(ec.getGPUContexts());
        calleeEc.getGPUContext(0).initializeThread();
    }
    calleeEc.setVariables(calleeVars);

    // execute the function block
    try {
        callee._functionName = this._functionName;
        callee._namespace = this._namespace;
        callee.execute(calleeEc);
    } catch (DMLScriptException e) {
        // script-level exceptions are passed through unwrapped
        throw e;
    } catch (Exception e) {
        String failedKey = DMLProgram.constructFunctionKey(_namespace, _functionName);
        throw new DMLRuntimeException("error executing function " + failedKey, e);
    }

    // cleanup all returned variables w/o binding (avoids leaks of unexpected results)
    final HashSet<String> declaredOutputs = new HashSet<>();
    for (DataIdentifier out : callee.getOutputParams()) {
        declaredOutputs.add(out.getName());
    }
    final LocalVariableMap results = calleeEc.getVariables();
    for (Entry<String, Data> entry : results.entrySet()) {
        if (!declaredOutputs.contains(entry.getKey()) && entry.getValue() instanceof CacheableData) {
            calleeEc.cleanupCacheableData((CacheableData<?>) entry.getValue());
        }
    }

    // unpin the pinned variables
    ec.unpinVariables(_boundInputNames, pinned);

    // add the updated binding for each return variable to the caller's symbol table
    int outPos = 0;
    for (DataIdentifier out : callee.getOutputParams()) {
        final String targetName = _boundOutputNames.get(outPos++);
        final Data result = results.get(out.getName());
        if (result == null) {
            throw new DMLRuntimeException(targetName + " was not assigned a return value");
        }
        // cleanup existing data bound to output variable name (unless it is the result itself)
        final Data previous = ec.removeVariable(targetName);
        if (previous instanceof CacheableData && previous != result) {
            ec.cleanupCacheableData((CacheableData<?>) previous);
        }
        // add/replace data in symbol table
        ec.setVariable(targetName, result);
    }
}
use of org.apache.sysml.runtime.controlprogram.FunctionProgramBlock in project incubator-systemml by apache.
the class Explain method countCompiledInstructions.
/**
 * Recursively counts the compiled instructions of the given runtime program block
 * and all of its nested blocks.
 *
 * <p>For while and if blocks the predicate instructions are counted before the
 * bodies; for for/parfor blocks the from, to and increment instructions are counted
 * before the body; for function blocks only the body is visited; for any other
 * block its own instructions are counted.
 *
 * @param pb program block
 * @param counts explain counts to accumulate into
 * @param MR if true, count Hadoop instructions
 * @param CP if true, count CP instructions
 * @param SP if true, count Spark instructions
 */
private static void countCompiledInstructions(ProgramBlock pb, ExplainCounts counts, boolean MR, boolean CP, boolean SP) {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock whileBlock = (WhileProgramBlock) pb;
        countCompiledInstructions(whileBlock.getPredicate(), counts, MR, CP, SP);
        for (ProgramBlock child : whileBlock.getChildBlocks()) {
            countCompiledInstructions(child, counts, MR, CP, SP);
        }
        return;
    }

    if (pb instanceof IfProgramBlock) {
        IfProgramBlock ifBlock = (IfProgramBlock) pb;
        countCompiledInstructions(ifBlock.getPredicate(), counts, MR, CP, SP);
        for (ProgramBlock child : ifBlock.getChildBlocksIfBody()) {
            countCompiledInstructions(child, counts, MR, CP, SP);
        }
        for (ProgramBlock child : ifBlock.getChildBlocksElseBody()) {
            countCompiledInstructions(child, counts, MR, CP, SP);
        }
        return;
    }

    // includes ParFORProgramBlock; additional parfor jobs are counted during runtime
    if (pb instanceof ForProgramBlock) {
        ForProgramBlock forBlock = (ForProgramBlock) pb;
        countCompiledInstructions(forBlock.getFromInstructions(), counts, MR, CP, SP);
        countCompiledInstructions(forBlock.getToInstructions(), counts, MR, CP, SP);
        countCompiledInstructions(forBlock.getIncrementInstructions(), counts, MR, CP, SP);
        for (ProgramBlock child : forBlock.getChildBlocks()) {
            countCompiledInstructions(child, counts, MR, CP, SP);
        }
        return;
    }

    // includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
    if (pb instanceof FunctionProgramBlock) {
        FunctionProgramBlock funcBlock = (FunctionProgramBlock) pb;
        for (ProgramBlock child : funcBlock.getChildBlocks()) {
            countCompiledInstructions(child, counts, MR, CP, SP);
        }
        return;
    }

    // basic block: count its own instructions
    countCompiledInstructions(pb.getInstructions(), counts, MR, CP, SP);
}
Aggregations