use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by apache.
the class InterProceduralAnalysis method propagateStatisticsAcrossBlock.
// ///////////////////////////
// INTRA-PROCEDURE ANALYSIS
// ////
private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) {
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
} else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
// old stats into predicate
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, wsb);
// check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb)) {
// second pass if required
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
}
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
// old stats into predicate
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
// check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : istmt.getIfBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
for (StatementBlock sbi : istmt.getElseBody()) propagateStatisticsAcrossBlock(sbi, callVarsElse, fcallSizes, fnStack);
callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (// incl parfor
sb instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
// old stats into predicate
propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, fsb);
// check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb))
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else // generic (last-level)
{
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
// old stats in, new stats out if updated
ArrayList<Hop> roots = sb.getHops();
DMLProgram prog = sb.getDMLProg();
// replace scalar reads with literals
Hop.resetVisitStatus(roots);
propagateScalarsAcrossDAG(roots, callVars);
// refresh stats across dag
Hop.resetVisitStatus(roots);
propagateStatisticsAcrossDAG(roots, callVars);
// propagate stats into function calls
Hop.resetVisitStatus(roots);
propagateStatisticsIntoFunctions(prog, roots, callVars, fcallSizes, fnStack);
}
}
use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by apache.
the class InterProceduralAnalysis method isUnarySizePreservingFunction.
private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) {
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
// check unary functions over matrices
boolean ret = (fstmt.getInputParams().size() == 1 && fstmt.getInputParams().get(0).getDataType() == DataType.MATRIX && fstmt.getOutputParams().size() == 1 && fstmt.getOutputParams().get(0).getDataType() == DataType.MATRIX);
// check size-preserving characteristic
if (ret) {
FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph, false);
HashSet<String> fnStack = new HashSet<>();
LocalVariableMap callVars = new LocalVariableMap();
// populate input
MatrixObject mo = createOutputMatrix(7777, 3333, -1);
callVars.put(fstmt.getInputParams().get(0).getName(), mo);
// propagate statistics
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
// compare output
MatrixObject mo2 = (MatrixObject) callVars.get(fstmt.getOutputParams().get(0).getName());
ret &= mo.getNumRows() == mo2.getNumRows() && mo.getNumColumns() == mo2.getNumColumns();
// reset function
mo.getMatrixCharacteristics().setDimension(-1, -1);
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
}
return ret;
}
use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by apache.
the class ProgramRewriter method rRewriteStatementBlock.
public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) {
ArrayList<StatementBlock> ret = new ArrayList<>();
ret.add(sb);
// recursive invocation
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
} else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
wstmt.setBody(rRewriteStatementBlocks(wstmt.getBody(), status, splitDags));
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
istmt.setIfBody(rRewriteStatementBlocks(istmt.getIfBody(), status, splitDags));
istmt.setElseBody(rRewriteStatementBlocks(istmt.getElseBody(), status, splitDags));
} else if (sb instanceof ForStatementBlock) {
// incl parfor
// maintain parfor context information (e.g., for checkpointing)
boolean prestatus = status.isInParforContext();
if (sb instanceof ParForStatementBlock)
status.setInParforContext(true);
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
status.setInParforContext(prestatus);
}
// apply rewrite rules to individual statement blocks
for (StatementBlockRewriteRule r : _sbRuleSet) {
if (!splitDags && r.createsSplitDag())
continue;
ArrayList<StatementBlock> tmp = new ArrayList<>();
for (StatementBlock sbc : ret) tmp.addAll(r.rewriteStatementBlock(sbc, status));
// take over set of rewritten sbs
ret.clear();
ret.addAll(tmp);
}
return ret;
}
use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by apache.
the class OptTreeConverter method rCreateAbstractOptNodes.
public static ArrayList<OptNode> rCreateAbstractOptNodes(Hop hop, LocalVariableMap vars, Set<String> memo) {
ArrayList<OptNode> ret = new ArrayList<>();
ArrayList<Hop> in = hop.getInput();
if (hop.isVisited())
return ret;
// general case
if (!(hop instanceof DataOp || hop instanceof LiteralOp || hop instanceof FunctionOp)) {
OptNode node = new OptNode(NodeType.HOP);
String opstr = hop.getOpString();
node.addParam(ParamType.OPSTRING, opstr);
// handle execution type
LopProperties.ExecType et = (hop.getExecType() != null) ? hop.getExecType() : LopProperties.ExecType.CP;
switch(et) {
case CP:
case GPU:
node.setExecType(ExecType.CP);
break;
case SPARK:
node.setExecType(ExecType.SPARK);
break;
case MR:
node.setExecType(ExecType.MR);
break;
default:
throw new DMLRuntimeException("Unsupported optnode exec type: " + et);
}
// handle degree of parallelism
if (et == LopProperties.ExecType.CP && hop instanceof MultiThreadedHop) {
MultiThreadedHop mtop = (MultiThreadedHop) hop;
node.setK(OptimizerUtils.getConstrainedNumThreads(mtop.getMaxNumThreads()));
}
// assign node to return
_hlMap.putHopMapping(hop, node);
ret.add(node);
} else // process function calls
if (hop instanceof FunctionOp && INCLUDE_FUNCTIONS) {
FunctionOp fhop = (FunctionOp) hop;
String fname = fhop.getFunctionName();
String fnspace = fhop.getFunctionNamespace();
String fKey = fhop.getFunctionKey();
Object[] prog = _hlMap.getRootProgram();
OptNode node = new OptNode(NodeType.FUNCCALL);
_hlMap.putHopMapping(fhop, node);
node.setExecType(ExecType.CP);
node.addParam(ParamType.OPSTRING, fKey);
if (!fnspace.equals(DMLProgram.INTERNAL_NAMESPACE)) {
FunctionProgramBlock fpb = ((Program) prog[1]).getFunctionProgramBlock(fnspace, fname);
FunctionStatementBlock fsb = ((DMLProgram) prog[0]).getFunctionStatementBlock(fnspace, fname);
FunctionStatement fs = (FunctionStatement) fsb.getStatement(0);
// process body; NOTE: memo prevents inclusion of functions multiple times
if (!memo.contains(fKey)) {
memo.add(fKey);
int len = fs.getBody().size();
for (int i = 0; i < fpb.getChildBlocks().size() && i < len; i++) {
ProgramBlock lpb = fpb.getChildBlocks().get(i);
StatementBlock lsb = fs.getBody().get(i);
node.addChild(rCreateAbstractOptNode(lsb, lpb, vars, false, memo));
}
memo.remove(fKey);
} else
node.addParam(ParamType.RECURSIVE_CALL, "true");
}
ret.add(node);
}
if (in != null)
for (Hop hin : in) if (// no need for opt nodes
!(hin instanceof DataOp || hin instanceof LiteralOp))
ret.addAll(rCreateAbstractOptNodes(hin, vars, memo));
hop.setVisited();
return ret;
}
use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by apache.
the class OptTreePlanChecker method checkProgramCorrectness.
public static void checkProgramCorrectness(ProgramBlock pb, StatementBlock sb, Set<String> fnStack) {
Program prog = pb.getProgram();
DMLProgram dprog = sb.getDMLProg();
if (pb instanceof FunctionProgramBlock && sb instanceof FunctionStatementBlock) {
FunctionProgramBlock fpb = (FunctionProgramBlock) pb;
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (int i = 0; i < fpb.getChildBlocks().size(); i++) {
ProgramBlock pbc = fpb.getChildBlocks().get(i);
StatementBlock sbc = fstmt.getBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
// checkLinksProgramStatementBlock(fpb, fsb);
} else if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
checkHopDagCorrectness(prog, dprog, wsb.getPredicateHops(), wpb.getPredicate(), fnStack);
for (int i = 0; i < wpb.getChildBlocks().size(); i++) {
ProgramBlock pbc = wpb.getChildBlocks().get(i);
StatementBlock sbc = wstmt.getBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
checkLinksProgramStatementBlock(wpb, wsb);
} else if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
IfProgramBlock ipb = (IfProgramBlock) pb;
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
checkHopDagCorrectness(prog, dprog, isb.getPredicateHops(), ipb.getPredicate(), fnStack);
for (int i = 0; i < ipb.getChildBlocksIfBody().size(); i++) {
ProgramBlock pbc = ipb.getChildBlocksIfBody().get(i);
StatementBlock sbc = istmt.getIfBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
for (int i = 0; i < ipb.getChildBlocksElseBody().size(); i++) {
ProgramBlock pbc = ipb.getChildBlocksElseBody().get(i);
StatementBlock sbc = istmt.getElseBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
checkLinksProgramStatementBlock(ipb, isb);
} else if (// incl parfor
pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) {
ForProgramBlock fpb = (ForProgramBlock) pb;
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) sb.getStatement(0);
checkHopDagCorrectness(prog, dprog, fsb.getFromHops(), fpb.getFromInstructions(), fnStack);
checkHopDagCorrectness(prog, dprog, fsb.getToHops(), fpb.getToInstructions(), fnStack);
checkHopDagCorrectness(prog, dprog, fsb.getIncrementHops(), fpb.getIncrementInstructions(), fnStack);
for (int i = 0; i < fpb.getChildBlocks().size(); i++) {
ProgramBlock pbc = fpb.getChildBlocks().get(i);
StatementBlock sbc = fstmt.getBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
checkLinksProgramStatementBlock(fpb, fsb);
} else {
checkHopDagCorrectness(prog, dprog, sb.getHops(), pb.getInstructions(), fnStack);
// checkLinksProgramStatementBlock(pb, sb);
}
}
Aggregations