use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
the class Recompiler method rRecompileProgramBlock2Forced.
private static void rRecompileProgramBlock2Forced(ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et) {
if (pb instanceof WhileProgramBlock) {
WhileProgramBlock pbTmp = (WhileProgramBlock) pb;
WhileStatementBlock sbTmp = (WhileStatementBlock) pbTmp.getStatementBlock();
// recompile predicate
if (sbTmp != null && !(et == ExecType.CP && !OptTreeConverter.containsMRJobInstruction(pbTmp.getPredicate(), true, true)))
pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
// recompile body
for (ProgramBlock pb2 : pbTmp.getChildBlocks()) rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
} else if (pb instanceof IfProgramBlock) {
IfProgramBlock pbTmp = (IfProgramBlock) pb;
IfStatementBlock sbTmp = (IfStatementBlock) pbTmp.getStatementBlock();
// recompile predicate
if (sbTmp != null && !(et == ExecType.CP && !OptTreeConverter.containsMRJobInstruction(pbTmp.getPredicate(), true, true)))
pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
// recompile body
for (ProgramBlock pb2 : pbTmp.getChildBlocksIfBody()) rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
for (ProgramBlock pb2 : pbTmp.getChildBlocksElseBody()) rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
} else if (// includes ParFORProgramBlock
pb instanceof ForProgramBlock) {
ForProgramBlock pbTmp = (ForProgramBlock) pb;
ForStatementBlock sbTmp = (ForStatementBlock) pbTmp.getStatementBlock();
// recompile predicate
if (sbTmp != null && sbTmp.getFromHops() != null && !(et == ExecType.CP && !OptTreeConverter.containsMRJobInstruction(pbTmp.getFromInstructions(), true, true)))
pbTmp.setFromInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getFromHops(), tid, et));
if (sbTmp != null && sbTmp.getToHops() != null && !(et == ExecType.CP && !OptTreeConverter.containsMRJobInstruction(pbTmp.getToInstructions(), true, true)))
pbTmp.setToInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getToHops(), tid, et));
if (sbTmp != null && sbTmp.getIncrementHops() != null && !(et == ExecType.CP && !OptTreeConverter.containsMRJobInstruction(pbTmp.getIncrementInstructions(), true, true)))
pbTmp.setIncrementInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getIncrementHops(), tid, et));
// recompile body
for (ProgramBlock pb2 : pbTmp.getChildBlocks()) rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
} else if (// includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
pb instanceof FunctionProgramBlock) {
FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
for (ProgramBlock pb2 : tmp.getChildBlocks()) rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
} else {
StatementBlock sb = pb.getStatementBlock();
// would be invalid with permutation matrix mult across multiple dags)
if (sb != null) {
ArrayList<Instruction> tmp = pb.getInstructions();
tmp = Recompiler.recompileHopsDag2Forced(sb, sb.getHops(), tid, et);
pb.setInstructions(tmp);
}
// recompile functions
if (OptTreeConverter.containsFunctionCallInstruction(pb)) {
ArrayList<Instruction> tmp = pb.getInstructions();
for (Instruction inst : tmp) if (inst instanceof FunctionCallCPInstruction) {
FunctionCallCPInstruction func = (FunctionCallCPInstruction) inst;
String fname = func.getFunctionName();
String fnamespace = func.getNamespace();
String fKey = DMLProgram.constructFunctionKey(fnamespace, fname);
if (// memoization for multiple calls, recursion
!fnStack.contains(fKey)) {
fnStack.add(fKey);
FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fnamespace, fname);
// recompile chains of functions
rRecompileProgramBlock2Forced(fpb, tid, fnStack, et);
}
}
}
}
}
use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
the class Recompiler method rRecompileProgramBlock.
// ////////////////////////////
// private helper functions //
// ////////////////////////////
private static void rRecompileProgramBlock(ProgramBlock pb, LocalVariableMap vars, RecompileStatus status, long tid, ResetType resetRecompile) {
if (pb instanceof WhileProgramBlock) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
WhileStatementBlock wsb = (WhileStatementBlock) wpb.getStatementBlock();
// recompile predicate
recompileWhilePredicate(wpb, wsb, vars, status, tid, resetRecompile);
// remove updated scalars because in loop
removeUpdatedScalars(vars, wsb);
// copy vars for later compare
LocalVariableMap oldVars = (LocalVariableMap) vars.clone();
RecompileStatus oldStatus = (RecompileStatus) status.clone();
for (ProgramBlock pb2 : wpb.getChildBlocks()) rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
if (reconcileUpdatedCallVarsLoops(oldVars, vars, wsb) | reconcileUpdatedCallVarsLoops(oldStatus, status, wsb)) {
// second pass with unknowns if required
recompileWhilePredicate(wpb, wsb, vars, status, tid, resetRecompile);
for (ProgramBlock pb2 : wpb.getChildBlocks()) rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
}
removeUpdatedScalars(vars, wsb);
} else if (pb instanceof IfProgramBlock) {
IfProgramBlock ipb = (IfProgramBlock) pb;
IfStatementBlock isb = (IfStatementBlock) ipb.getStatementBlock();
// recompile predicate
recompileIfPredicate(ipb, isb, vars, status, tid, resetRecompile);
// copy vars for later compare
LocalVariableMap oldVars = (LocalVariableMap) vars.clone();
LocalVariableMap varsElse = (LocalVariableMap) vars.clone();
RecompileStatus oldStatus = (RecompileStatus) status.clone();
RecompileStatus statusElse = (RecompileStatus) status.clone();
for (ProgramBlock pb2 : ipb.getChildBlocksIfBody()) rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
for (ProgramBlock pb2 : ipb.getChildBlocksElseBody()) rRecompileProgramBlock(pb2, varsElse, statusElse, tid, resetRecompile);
reconcileUpdatedCallVarsIf(oldVars, vars, varsElse, isb);
reconcileUpdatedCallVarsIf(oldStatus, status, statusElse, isb);
removeUpdatedScalars(vars, ipb.getStatementBlock());
} else if (// includes ParFORProgramBlock
pb instanceof ForProgramBlock) {
ForProgramBlock fpb = (ForProgramBlock) pb;
ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
// recompile predicates
recompileForPredicates(fpb, fsb, vars, status, tid, resetRecompile);
// remove updated scalars because in loop
removeUpdatedScalars(vars, fpb.getStatementBlock());
// copy vars for later compare
LocalVariableMap oldVars = (LocalVariableMap) vars.clone();
RecompileStatus oldStatus = (RecompileStatus) status.clone();
for (ProgramBlock pb2 : fpb.getChildBlocks()) rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
if (reconcileUpdatedCallVarsLoops(oldVars, vars, fsb) | reconcileUpdatedCallVarsLoops(oldStatus, status, fsb)) {
// second pass with unknowns if required
recompileForPredicates(fpb, fsb, vars, status, tid, resetRecompile);
for (ProgramBlock pb2 : fpb.getChildBlocks()) rRecompileProgramBlock(pb2, vars, status, tid, resetRecompile);
}
removeUpdatedScalars(vars, fpb.getStatementBlock());
} else if (// includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
pb instanceof FunctionProgramBlock) {
// do nothing
} else {
StatementBlock sb = pb.getStatementBlock();
ArrayList<Instruction> tmp = pb.getInstructions();
if (sb == null)
return;
// recompile all for stats propagation and recompile flags
tmp = Recompiler.recompileHopsDag(sb, sb.getHops(), vars, status, true, false, tid);
pb.setInstructions(tmp);
// propagate stats across hops (should be executed on clone of vars)
Recompiler.extractDAGOutputStatistics(sb.getHops(), vars);
// reset recompilation flags (w/ special handling functions)
if (ParForProgramBlock.RESET_RECOMPILATION_FLAGs && !containsRootFunctionOp(sb.getHops()) && resetRecompile.isReset()) {
Hop.resetRecompilationFlag(sb.getHops(), ExecType.CP, resetRecompile);
sb.updateRecompilationFlag();
}
}
}
use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
the class ProgramRewriter method rRewriteStatementBlocks.
public ArrayList<StatementBlock> rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) {
// ensure robustness for calls from outside
if (status == null)
status = new ProgramRewriteStatus();
// apply rewrite rules to list of statement blocks
List<StatementBlock> tmp = sbs;
for (StatementBlockRewriteRule r : _sbRuleSet) if (splitDags || !r.createsSplitDag())
tmp = r.rewriteStatementBlocks(tmp, status);
// recursively rewrite statement blocks (with potential expansion)
List<StatementBlock> tmp2 = new ArrayList<>();
for (StatementBlock sb : tmp) tmp2.addAll(rRewriteStatementBlock(sb, status, splitDags));
// apply rewrite rules to list of statement blocks (with potential contraction)
for (StatementBlockRewriteRule r : _sbRuleSet) if (splitDags || !r.createsSplitDag())
tmp2 = r.rewriteStatementBlocks(tmp2, status);
// prepare output list
sbs.clear();
sbs.addAll(tmp2);
return sbs;
}
use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
the class ProgramRewriter method rRewriteStatementBlockHopDAGs.
public void rRewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) {
// ensure robustness for calls from outside
if (state == null)
state = new ProgramRewriteStatus();
if (current instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) current;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock sb : fstmt.getBody()) rRewriteStatementBlockHopDAGs(sb, state);
} else if (current instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) current;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
for (StatementBlock sb : wstmt.getBody()) rRewriteStatementBlockHopDAGs(sb, state);
} else if (current instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) current;
IfStatement istmt = (IfStatement) isb.getStatement(0);
isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
for (StatementBlock sb : istmt.getIfBody()) rRewriteStatementBlockHopDAGs(sb, state);
for (StatementBlock sb : istmt.getElseBody()) rRewriteStatementBlockHopDAGs(sb, state);
} else if (// incl parfor
current instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) current;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state));
fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
for (StatementBlock sb : fstmt.getBody()) rRewriteStatementBlockHopDAGs(sb, state);
} else // generic (last-level)
{
current.setHops(rewriteHopDAG(current.getHops(), state));
}
}
use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
the class ProgramRewriter method rewriteProgramHopDAGs.
public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
ProgramRewriteStatus state = new ProgramRewriteStatus();
// for each namespace, handle function statement blocks
for (String namespaceKey : dmlp.getNamespaces().keySet()) for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname);
rRewriteStatementBlockHopDAGs(fsblock, state);
if (!_sbRuleSet.isEmpty())
rRewriteStatementBlock(fsblock, state, splitDags);
}
// handle regular statement blocks in "main" method
for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
StatementBlock current = dmlp.getStatementBlock(i);
rRewriteStatementBlockHopDAGs(current, state);
}
if (!_sbRuleSet.isEmpty())
dmlp.setStatementBlocks(rRewriteStatementBlocks(dmlp.getStatementBlocks(), state, splitDags));
return state;
}
Aggregations