Usage of org.apache.sysml.parser.StatementBlock in the Apache incubator-systemml project.
Class FunctionCallGraph, method constructFunctionCallGraph.
/**
 * Builds the function call graph for the given program, starting from its
 * top-level statement blocks under {@code MAIN_FUNCTION_KEY}.
 *
 * @param prog the program to analyze
 * @return false if the program has no function statement blocks; otherwise the
 *         OR of {@code rConstructFunctionCallGraph} over all top-level blocks
 * @throws RuntimeException wrapping any {@code HopsException} raised during traversal
 */
private boolean constructFunctionCallGraph(DMLProgram prog) {
    // Without function definitions there is nothing to build.
    if (!prog.hasFunctionStatementBlocks()) {
        return false;
    }

    boolean result = false;
    try {
        // The same stack and set instances are shared by all top-level traversals.
        Stack<String> functionStack = new Stack<>();
        HashSet<String> lfKeys = new HashSet<>();

        // Add an empty entry for the main program before traversing it.
        _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());

        for (StatementBlock block : prog.getStatementBlocks()) {
            // Non-short-circuit OR: every top-level block is always traversed.
            result |= rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, block, functionStack, lfKeys);
        }
    } catch (HopsException ex) {
        // Rethrow unchecked while preserving the original cause.
        throw new RuntimeException(ex);
    }
    return result;
}
Usage of org.apache.sysml.parser.StatementBlock in the Apache incubator-systemml project.
Class IPAPassFlagFunctionsRecompileOnce, method rFlagFunctionForRecompileOnce.
/**
 * Decides whether a statement block (recursively) requires recompilation
 * inside a loop statement block.
 * <p>
 * Rules, checked in this order:
 * <ul>
 * <li>function block: true if any block of the function body qualifies;</li>
 * <li>while block: always true, because recompilation information is not yet
 *     available when this runs;</li>
 * <li>if block: true if {@code inLoop} is set and the predicate requires
 *     recompilation, or if any block of either branch qualifies;</li>
 * <li>for block: always true, for the same reason as while blocks;</li>
 * <li>any other block: true if {@code inLoop} is set and the block itself
 *     requires recompilation.</li>
 * </ul>
 * This method only forwards {@code inLoop} unchanged to nested blocks.
 *
 * @param sb statement block to inspect
 * @param inLoop true if {@code sb} is located inside a loop
 * @return true if the statement block requires recompilation inside a loop statement block
 */
public boolean rFlagFunctionForRecompileOnce(StatementBlock sb, boolean inLoop) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        boolean flag = false;
        for (StatementBlock child : funcStmt.getBody()) {
            // visit every child, no short-circuit
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        return flag;
    }
    if (sb instanceof WhileStatementBlock) {
        // loop: recompile information unknown here, so mark unconditionally
        return true;
    }
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        // predicate only matters when we are inside a loop
        boolean flag = inLoop && ifBlock.requiresPredicateRecompilation();
        for (StatementBlock child : ifStmt.getIfBody()) {
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        return flag;
    }
    if (sb instanceof ForStatementBlock) {
        // loop: recompile information unknown here, so mark unconditionally
        return true;
    }
    return inLoop && sb.requiresRecompilation();
}
Usage of org.apache.sysml.parser.StatementBlock in the Apache incubator-systemml project.
Class IPAPassInlineFunctions, method rewriteProgram.
/**
 * Inlines small, simple functions into their call sites.
 * <p>
 * A reachable function is a candidate when its body consists of exactly one
 * last-level statement block without nested function ops, and it is either
 * called exactly once or has at most
 * {@code InterProceduralAnalysis.INLINING_MAX_NUM_OPS} operators (as reported
 * by {@code countOperators}). For each call, a deep copy of the body's hop DAG
 * replaces the {@link FunctionOp} in the calling statement block. Transient
 * reads are rewritten via {@code replaceTransientReads} using a map from input
 * parameter name to the call's input hop. Transient writes are renamed from the
 * output parameter name to the call's output variable name.
 * <p>
 * A function is only inlined if every call site has the same number of inputs
 * and outputs as the function signature. Otherwise the function is left
 * untouched, so that {@code fgraph.removeFunctionCalls(fkey)} never drops a call
 * that still exists in the program.
 *
 * @param prog the program, rewritten in place
 * @param fgraph the function call graph; calls of inlined functions are removed from it
 * @param fcallSizes not used by this method
 */
@Override
public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
    for (String fkey : fgraph.getReachableFunctions()) {
        FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);

        // candidate check (same conditions and evaluation order as before)
        if (fstmt.getBody().size() != 1) {
            continue;
        }
        StatementBlock body = fstmt.getBody().get(0);
        if (!HopRewriteUtils.isLastLevelStatementBlock(body) || containsFunctionOp(body.getHops())) {
            continue;
        }
        List<FunctionOp> fcalls = fgraph.getFunctionCalls(fkey);
        if (fcalls.size() != 1 && countOperators(body.getHops()) > InterProceduralAnalysis.INLINING_MAX_NUM_OPS) {
            continue;
        }

        // step 0: robustness for special cases. Validate ALL call sites up front.
        // If some calls were skipped while others were inlined,
        // removeFunctionCalls(fkey) below would still drop the skipped calls
        // from the graph even though they remain in the program.
        boolean allCallsCompatible = true;
        for (FunctionOp op : fcalls) {
            if (op.getInput().size() != fstmt.getInputParams().size()
                || op.getOutputVariableNames().length != fstmt.getOutputParams().size()) {
                allCallsCompatible = false;
                break;
            }
        }
        if (!allCallsCompatible) {
            if (LOG.isDebugEnabled())
                LOG.debug("IPA: Skip inlining function '" + fkey + "' (call with mismatching number of inputs/outputs)");
            continue;
        }

        if (LOG.isDebugEnabled())
            LOG.debug("IPA: Inline function '" + fkey + "'");

        // replace all relevant function calls
        ArrayList<Hop> hops = body.getHops();
        List<StatementBlock> fcallsSB = fgraph.getFunctionCallsSB(fkey);
        for (int i = 0; i < fcalls.size(); i++) {
            FunctionOp op = fcalls.get(i);

            // step 1: deep copy hop dag
            ArrayList<Hop> hops2 = Recompiler.deepCopyHopsDag(hops);

            // step 2: replace inputs
            HashMap<String, Hop> inMap = new HashMap<>();
            for (int j = 0; j < op.getInput().size(); j++) {
                inMap.put(fstmt.getInputParams().get(j).getName(), op.getInput().get(j));
            }
            replaceTransientReads(hops2, inMap);

            // step 3: replace outputs
            HashMap<String, String> outMap = new HashMap<>();
            String[] opOutputs = op.getOutputVariableNames();
            for (int j = 0; j < opOutputs.length; j++) {
                outMap.put(fstmt.getOutputParams().get(j).getName(), opOutputs[j]);
            }
            for (int j = 0; j < hops2.size(); j++) {
                Hop out = hops2.get(j);
                if (HopRewriteUtils.isData(out, DataOpTypes.TRANSIENTWRITE)) {
                    out.setName(outMap.get(out.getName()));
                }
            }

            // step 4: swap the function call for the inlined DAG roots
            fcallsSB.get(i).getHops().remove(op);
            fcallsSB.get(i).getHops().addAll(hops2);
        }

        // update the function call graph to avoid repeated inlining
        // (and thus op replication) on repeated IPA calls; safe because
        // every call of fkey has been inlined at this point
        fgraph.removeFunctionCalls(fkey);
    }
}
Usage of org.apache.sysml.parser.StatementBlock in the Apache incubator-systemml project.
Class IPAPassPropagateReplaceLiterals, method rReplaceLiterals.
/**
 * Recursively applies {@code replaceLiterals} to the hop DAGs of a statement
 * block and all of its nested blocks.
 * <p>
 * First, every variable reported as updated by {@code sb} is removed from
 * {@code constants}. Then the block is processed by type:
 * <ul>
 * <li>while: predicate hops, then each body block;</li>
 * <li>if: predicate hops, then each if-body block, then each else-body block;</li>
 * <li>for: from, to and increment hops, then each body block;</li>
 * <li>anything else: the block's own hops.</li>
 * </ul>
 * The same {@code constants} map is shared with, and mutated by, all recursive calls.
 *
 * @param sb the statement block to process
 * @param constants map of candidate constant values; entries are removed as variables are updated
 */
private void rReplaceLiterals(StatementBlock sb, LocalVariableMap constants) {
    // variables written in this block can no longer be treated as constants
    for (String name : sb.variablesUpdated().getVariableNames()) {
        if (constants.keySet().contains(name)) {
            constants.remove(name);
        }
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        replaceLiterals(whileBlock.getPredicateHops(), constants);
        for (StatementBlock child : whileStmt.getBody()) {
            rReplaceLiterals(child, constants);
        }
    } else if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        replaceLiterals(ifBlock.getPredicateHops(), constants);
        for (StatementBlock child : ifStmt.getIfBody()) {
            rReplaceLiterals(child, constants);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rReplaceLiterals(child, constants);
        }
    } else if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        replaceLiterals(forBlock.getFromHops(), constants);
        replaceLiterals(forBlock.getToHops(), constants);
        replaceLiterals(forBlock.getIncrementHops(), constants);
        for (StatementBlock child : forStmt.getBody()) {
            rReplaceLiterals(child, constants);
        }
    } else {
        // basic (non-control-flow) block
        replaceLiterals(sb.getHops(), constants);
    }
}
Usage of org.apache.sysml.parser.StatementBlock in the Apache incubator-systemml project.
Class IPAPassRemoveUnnecessaryCheckpoints, method removeCheckpointBeforeUpdate.
/**
 * Disables checkpoints on top-level variables that are superseded by a later
 * checkpoint of the same variable name without a detected intervening use.
 * <p>
 * Iterates over the program's top-level statement blocks and keeps a map of
 * candidate checkpoint hops keyed by variable name. For each block:
 * (1) candidates that the block reads but does not update are dropped, unless a
 * hop-level check finds no read in the block's hop DAGs;
 * (2) candidates updated inside if/while/for blocks are dropped, and candidates
 * updated in other blocks are dropped when an updating root of that name fails
 * {@code HopRewriteUtils.rHasSimpleReadChain};
 * (3) for last-level blocks, each collected checkpoint disables a surviving
 * candidate of the same name (via {@code setRequiresCheckpoint(false)}) and then
 * becomes the new candidate.
 *
 * @param dmlp the program whose top-level hop DAGs are modified in place
 */
private static void removeCheckpointBeforeUpdate(DMLProgram dmlp) {
// approach: scan over top-level program (guaranteed to be unconditional),
// collect checkpoints; determine if used before update; remove first checkpoint
// on second checkpoint if update in between and not used before update
// candidate checkpoint hops, keyed by the checkpointed variable name
HashMap<String, Hop> chkpointCand = new HashMap<>();
for (StatementBlock sb : dmlp.getStatementBlocks()) {
// prune candidates (used before updated)
// iterate over a snapshot of the keys because candidates are removed below
Set<String> cands = new HashSet<>(chkpointCand.keySet());
for (String cand : cands) if (sb.variablesRead().containsVariable(cand) && !sb.variablesUpdated().containsVariable(cand)) {
// note: variableRead might include false positives due to meta
// data operations like nrow(X) or operations removed by rewrites
// double check hops on basic blocks; otherwise worst-case
// skipRemove ends up true only if the block has hops and no root DAG
// contains a read of cand; blocks without hops always remove the candidate
boolean skipRemove = false;
if (sb.getHops() != null) {
Hop.resetVisitStatus(sb.getHops());
skipRemove = true;
for (Hop root : sb.getHops()) skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
}
if (!skipRemove)
chkpointCand.remove(cand);
}
// prune candidates (updated in conditional control flow)
// fresh snapshot: reflects removals made by the read-pruning above
Set<String> cands2 = new HashSet<>(chkpointCand.keySet());
if (sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
for (String cand : cands2) if (sb.variablesUpdated().containsVariable(cand)) {
chkpointCand.remove(cand);
}
} else // prune candidates (updated w/ multiple reads)
{
// visit status is reset once per candidate, not per root
// NOTE(review): root.getName() is assumed non-null for every root hop here — confirm
for (String cand : cands2) if (sb.variablesUpdated().containsVariable(cand) && sb.getHops() != null) {
Hop.resetVisitStatus(sb.getHops());
for (Hop root : sb.getHops()) if (root.getName().equals(cand) && !HopRewriteUtils.rHasSimpleReadChain(root, cand)) {
chkpointCand.remove(cand);
}
}
}
// collect checkpoints and remove unnecessary checkpoints
// a new checkpoint of a surviving candidate's name disables the older
// checkpoint and replaces it as the candidate for subsequent blocks
if (HopRewriteUtils.isLastLevelStatementBlock(sb)) {
ArrayList<Hop> tmp = collectCheckpoints(sb.getHops());
for (Hop chkpoint : tmp) {
if (chkpointCand.containsKey(chkpoint.getName())) {
chkpointCand.get(chkpoint.getName()).setRequiresCheckpoint(false);
}
chkpointCand.put(chkpoint.getName(), chkpoint);
}
}
}
}
Aggregations