Usage of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.
Class ProgramRewriter, method rewriteProgramHopDAGs.
/**
 * Runs the program-level rewrite pass over a complete DML program.
 * <p>
 * Order of work:
 * <ol>
 *   <li>every function statement block of every namespace gets its HOP DAGs rewritten and
 *       then its statement-block structure rewritten;</li>
 *   <li>the HOP DAGs of each statement block of the main program are rewritten;</li>
 *   <li>the main program's statement-block list is replaced by the result of the
 *       statement-block rewrites.</li>
 * </ol>
 *
 * @param dmlp DML program, modified in place
 * @return the status object that was threaded through all rewrites of this pass
 * @throws LanguageException propagated from the rewrite helpers
 * @throws HopsException if a HOP rewrite fails
 */
public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) throws LanguageException, HopsException {
    ProgramRewriteStatus status = new ProgramRewriteStatus();

    // (1) function bodies, namespace by namespace
    for (String ns : dmlp.getNamespaces().keySet()) {
        for (String fnName : dmlp.getFunctionStatementBlocks(ns).keySet()) {
            FunctionStatementBlock fnBlock = dmlp.getFunctionStatementBlock(ns, fnName);
            rewriteStatementBlockHopDAGs(fnBlock, status);
            // The returned block list is discarded. The function body itself is updated
            // in place (FunctionStatement.setBody inside rewriteStatementBlock).
            // NOTE(review): a rule that replaced the function block itself would be lost here.
            rewriteStatementBlock(fnBlock, status);
        }
    }

    // (2) HOP DAGs of the main program, one statement block at a time
    for (int idx = 0; idx < dmlp.getNumStatementBlocks(); ++idx) {
        StatementBlock block = dmlp.getStatementBlock(idx);
        rewriteStatementBlockHopDAGs(block, status);
    }

    // (3) statement-block level rewrites of the main program
    dmlp.setStatementBlocks(rewriteStatementBlocks(dmlp.getStatementBlocks(), status));
    return status;
}
Usage of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.
Class ProgramRewriter, method rewriteStatementBlock.
/**
 * Recursively applies all statement-block rewrite rules to one statement block.
 * <p>
 * First, the nested bodies of function, while, if/else and for/parfor blocks are rewritten.
 * The results are written back through the statement's setBody/setIfBody/setElseBody.
 * Then each rule in {@code _sbRuleSet} is applied, in order, to every block produced by
 * the previous rule. A single rule may therefore replace one block by several.
 *
 * @param sb     statement block to rewrite
 * @param status rewrite status. Its parfor-context flag is forced to true while a parfor
 *               body is processed and restored to its previous value afterwards.
 * @return the statement blocks that replace {@code sb}
 * @throws HopsException if a rewrite fails
 */
private ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException {
    // descend into nested bodies first
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funcStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        funcStmt.setBody(rewriteStatementBlocks(funcStmt.getBody(), status));
    } else if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
        whileStmt.setBody(rewriteStatementBlocks(whileStmt.getBody(), status));
    } else if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
        ifStmt.setIfBody(rewriteStatementBlocks(ifStmt.getIfBody(), status));
        ifStmt.setElseBody(rewriteStatementBlocks(ifStmt.getElseBody(), status));
    } else if (sb instanceof ForStatementBlock) { // parfor blocks take this branch as well
        // keep the enclosing context so nested parfor information (e.g. for checkpointing)
        // does not leak past this body
        boolean outerParfor = status.isInParforContext();
        if (sb instanceof ParForStatementBlock)
            status.setInParforContext(true);
        ForStatement forStmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
        forStmt.setBody(rewriteStatementBlocks(forStmt.getBody(), status));
        status.setInParforContext(outerParfor);
    }

    // apply the rules one after another; each rule consumes the previous rule's output
    ArrayList<StatementBlock> current = new ArrayList<StatementBlock>();
    current.add(sb);
    for (StatementBlockRewriteRule rule : _sbRuleSet) {
        ArrayList<StatementBlock> produced = new ArrayList<StatementBlock>();
        for (StatementBlock candidate : current)
            produced.addAll(rule.rewriteStatementBlock(candidate, status));
        current = produced;
    }
    return current;
}
Usage of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.
Class InterProceduralAnalysis, method getFunctionCandidatesForStatisticPropagation.
/////////////////////////////
// GET FUNCTION CANDIDATES
//////
/**
 * Recursively walks a statement block and hands every HOP DAG root it contains to the
 * Hop-level overload of this method.
 * <p>
 * Control-flow blocks (function, while, if/else, for/parfor) are traversed through their
 * bodies. Last-level blocks pass each of their HOP DAG roots, together with the block's DML
 * program, to the overload. That overload presumably records function-call candidates in
 * {@code fcandCounts} / {@code fcandHops} (TODO confirm against its implementation).
 *
 * @param sb          statement block to inspect
 * @param fcandCounts candidate counts, passed through unchanged to the Hop-level overload
 * @param fcandHops   candidate function ops, passed through unchanged to the Hop-level overload
 * @throws HopsException  propagated from the Hop-level overload
 * @throws ParseException propagated from the Hop-level overload
 */
private void getFunctionCandidatesForStatisticPropagation(StatementBlock sb, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) throws HopsException, ParseException {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funcStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        for (StatementBlock child : funcStmt.getBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
        for (StatementBlock child : whileStmt.getBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        for (StatementBlock child : ifStmt.getElseBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    if (sb instanceof ForStatementBlock) { // parfor blocks take this branch as well
        ForStatement forStmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
        for (StatementBlock child : forStmt.getBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }

    // generic last-level block; the HOP list can be null (e.g. empty statement blocks)
    ArrayList<Hop> roots = sb.get_hops();
    if (roots == null)
        return;
    for (Hop root : roots)
        getFunctionCandidatesForStatisticPropagation(sb.getDMLProg(), root, fcandCounts, fcandHops);
}
Usage of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.
Class OptimizerRuleBased, method rFindAndUnfoldRecursiveFunction.
/**
 * Recursively searches the opt-tree rooted at {@code n} for a recursive function call and,
 * if {@code rContainsNode(n, parfor)} holds, "unfolds" that function into a renamed copy.
 * <p>
 * Unfolding does the following, in this order:
 * <ul>
 *   <li>deep-copies the function program block;</li>
 *   <li>registers the copy as {@code FUNCTION_UNFOLD_NAMEPREFIX + fname} in both the runtime
 *       {@code Program} and the {@code DMLProgram};</li>
 *   <li>relinks the old subtree to the new name;</li>
 *   <li>builds a new FUNCCALL opt node over the copied child blocks and swaps it into
 *       {@code n}'s parent;</li>
 *   <li>updates {@code recPBs} with the delta of parfor program blocks between the old and
 *       new subtrees.</li>
 * </ul>
 * <p>
 * Once a recursive FUNCCALL node is reached, the method returns without descending further,
 * whether or not it unfolded anything. Only other nodes are traversed into.
 *
 * @param n      current opt-tree node
 * @param parfor parfor program block under optimization
 * @param recPBs parfor program blocks. Updated in place: blocks found under the old subtree
 *               are removed and blocks found under the new subtree are added.
 * @param vars   variable map passed through to opt-node creation
 * @throws DMLRuntimeException propagated from program/opt-tree manipulation
 * @throws HopsException       propagated from opt-tree creation
 * @throws LanguageException   propagated from DML program manipulation
 */
protected void rFindAndUnfoldRecursiveFunction(OptNode n, ParForProgramBlock parfor, HashSet<ParForProgramBlock> recPBs, LocalVariableMap vars) throws DMLRuntimeException, HopsException, LanguageException {
//unfold if found
if (n.getNodeType() == NodeType.FUNCCALL && n.isRecursive()) {
// presumably checks whether the given parfor occurs in n's subtree (TODO confirm)
boolean exists = rContainsNode(n, parfor);
if (exists) {
// NOTE(review): split() treats KEY_DELIM as a regex; this assumes the key has the form
// <namespace> KEY_DELIM <name> (as built at fnameNewKey below); names[1] throws otherwise
String fnameKey = n.getParam(ParamType.OPSTRING);
String[] names = fnameKey.split(Program.KEY_DELIM);
String fnamespace = names[0];
String fname = names[1];
String fnameNew = FUNCTION_UNFOLD_NAMEPREFIX + fname;
//unfold function: register a deep copy under the new name in both the runtime program and the DML program
FunctionOp fop = (FunctionOp) OptTreeConverter.getAbstractPlanMapping().getMappedHop(n.getID());
Program prog = parfor.getProgram();
DMLProgram dmlprog = parfor.getStatementBlock().getDMLProg();
FunctionProgramBlock fpb = prog.getFunctionProgramBlock(fnamespace, fname);
FunctionProgramBlock copyfpb = ProgramConverter.createDeepCopyFunctionProgramBlock(fpb, new HashSet<String>(), new HashSet<String>());
prog.addFunctionProgramBlock(fnamespace, fnameNew, copyfpb);
dmlprog.addFunctionStatementBlock(fnamespace, fnameNew, (FunctionStatementBlock) copyfpb.getStatementBlock());
//replace function names in old subtree (link to new function)
rReplaceFunctionNames(n, fname, fnameNew);
//recreate sub opttree: new FUNCCALL node mapped to the same FunctionOp, swapped into n's parent
String fnameNewKey = fnamespace + Program.KEY_DELIM + fnameNew;
OptNode nNew = new OptNode(NodeType.FUNCCALL);
OptTreeConverter.getAbstractPlanMapping().putHopMapping(fop, nNew);
nNew.setExecType(ExecType.CP);
nNew.addParam(ParamType.OPSTRING, fnameNewKey);
long parentID = OptTreeConverter.getAbstractPlanMapping().getMappedParentID(n.getID());
OptTreeConverter.getAbstractPlanMapping().getOptNode(parentID).exchangeChild(n, nNew);
// both keys are seeded into memo, presumably so that calls to either name inside the copied body are not expanded again (TODO confirm in rCreateAbstractOptNode)
HashSet<String> memo = new HashSet<String>();
//required if functionop not shared (because not replaced yet)
memo.add(fnameKey);
//required if functionop shared (indirectly replaced)
memo.add(fnameNewKey);
// one child opt node per program block of the copied function
for (int i = 0; i < copyfpb.getChildBlocks().size(); /*&& i<len*/
i++) {
ProgramBlock lpb = copyfpb.getChildBlocks().get(i);
StatementBlock lsb = lpb.getStatementBlock();
nNew.addChild(OptTreeConverter.rCreateAbstractOptNode(lsb, lpb, vars, false, memo));
}
//compute delta for recPB set (use for removing parfor): drop parfor blocks of the old subtree, add those of the new one
recPBs.removeAll(rGetAllParForPBs(n, new HashSet<ParForProgramBlock>()));
recPBs.addAll(rGetAllParForPBs(nNew, new HashSet<ParForProgramBlock>()));
//replace function names in new subtree (recursive link to new function)
rReplaceFunctionNames(nNew, fname, fnameNew);
}
// recursive FUNCCALL nodes are never descended into, unfolded or not
return;
}
//recursive invocation (only for non-recursive functions)
if (!n.isLeaf())
for (OptNode c : n.getChilds()) rFindAndUnfoldRecursiveFunction(c, parfor, recPBs, vars);
}
Usage of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.
Class OptTreeConverter, method rCreateAbstractOptNodes.
/**
 * Recursively converts a HOP DAG into abstract opt nodes.
 * <p>
 * Node creation by HOP type:
 * <ul>
 *   <li>Regular operators (anything but DataOp, LiteralOp and FunctionOp) become HOP nodes.
 *       Each node carries the operator string, an exec type derived from the HOP's exec type
 *       (CP when unset; GPU maps to CP) and, for multi-threaded CP operators, a constrained
 *       degree of parallelism.</li>
 *   <li>Function calls become FUNCCALL nodes when {@code INCLUDE_FUNCTIONS} is set. For
 *       functions outside the internal namespace, the body is expanded into child nodes
 *       unless the function key is already in {@code memo}. In that case the node is marked
 *       as a recursive call instead.</li>
 * </ul>
 * Inputs that are not DataOp/LiteralOp are processed recursively. Each HOP is visited once:
 * it is marked visited at the end, and a visited HOP yields an empty list.
 *
 * @param hop  root of the HOP (sub)DAG
 * @param vars variable map passed through to nested opt-node creation
 * @param memo keys of functions currently being expanded. Temporarily extended while a
 *             function body is processed.
 * @return the created opt nodes, in creation order
 * @throws DMLRuntimeException for an unsupported exec type, or propagated from lookups
 * @throws HopsException propagated from nested conversion
 */
public static ArrayList<OptNode> rCreateAbstractOptNodes(Hop hop, LocalVariableMap vars, Set<String> memo) throws DMLRuntimeException, HopsException {
    ArrayList<OptNode> nodes = new ArrayList<OptNode>();
    if (hop.isVisited())
        return nodes;

    boolean isDataOrLiteral = hop instanceof DataOp || hop instanceof LiteralOp;
    if (!isDataOrLiteral && !(hop instanceof FunctionOp)) {
        // regular operator -> one HOP opt node
        OptNode opNode = new OptNode(NodeType.HOP);
        opNode.addParam(ParamType.OPSTRING, hop.getOpString());

        LopProperties.ExecType hopEt = (hop.getExecType() == null) ? LopProperties.ExecType.CP : hop.getExecType();
        if (hopEt == LopProperties.ExecType.CP || hopEt == LopProperties.ExecType.GPU) {
            opNode.setExecType(ExecType.CP);
        } else if (hopEt == LopProperties.ExecType.SPARK) {
            opNode.setExecType(ExecType.SPARK);
        } else if (hopEt == LopProperties.ExecType.MR) {
            opNode.setExecType(ExecType.MR);
        } else {
            throw new DMLRuntimeException("Unsupported optnode exec type: " + hopEt);
        }

        // degree of parallelism only for multi-threaded operators running in CP
        if (hopEt == LopProperties.ExecType.CP && hop instanceof MultiThreadedHop)
            opNode.setK(OptimizerUtils.getConstrainedNumThreads(((MultiThreadedHop) hop).getMaxNumThreads()));

        _hlMap.putHopMapping(hop, opNode);
        nodes.add(opNode);
    } else if (hop instanceof FunctionOp && INCLUDE_FUNCTIONS) {
        // function call -> FUNCCALL opt node, optionally with the expanded body as children
        FunctionOp callHop = (FunctionOp) hop;
        String fname = callHop.getFunctionName();
        String fnspace = callHop.getFunctionNamespace();
        String fKey = DMLProgram.constructFunctionKey(fnspace, fname);
        Object[] rootProg = _hlMap.getRootProgram();

        OptNode callNode = new OptNode(NodeType.FUNCCALL);
        _hlMap.putHopMapping(callHop, callNode);
        callNode.setExecType(ExecType.CP);
        callNode.addParam(ParamType.OPSTRING, fKey);

        if (!fnspace.equals(DMLProgram.INTERNAL_NAMESPACE)) {
            FunctionProgramBlock fpb = ((Program) rootProg[1]).getFunctionProgramBlock(fnspace, fname);
            FunctionStatementBlock fsb = ((DMLProgram) rootProg[0]).getFunctionStatementBlock(fnspace, fname);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);

            if (memo.contains(fKey)) {
                // function is already being expanded -> do not include it a second time
                callNode.addParam(ParamType.RECURSIVE_CALL, "true");
            } else {
                memo.add(fKey);
                int numBodyBlocks = fstmt.getBody().size();
                for (int j = 0; j < fpb.getChildBlocks().size() && j < numBodyBlocks; j++) {
                    ProgramBlock progBlock = fpb.getChildBlocks().get(j);
                    StatementBlock stmtBlock = fstmt.getBody().get(j);
                    callNode.addChild(rCreateAbstractOptNode(stmtBlock, progBlock, vars, false, memo));
                }
                memo.remove(fKey);
            }
        }
        nodes.add(callNode);
    }

    // descend into inputs; data and literal inputs never produce opt nodes
    ArrayList<Hop> inputs = hop.getInput();
    if (inputs != null) {
        for (Hop input : inputs) {
            if (input instanceof DataOp || input instanceof LiteralOp)
                continue;
            nodes.addAll(rCreateAbstractOptNodes(input, vars, memo));
        }
    }

    hop.setVisited();
    return nodes;
}
Aggregations