Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
The class ResourceOptimizer, method pruneHasOnlyUnknownMR.
private static boolean pruneHasOnlyUnknownMR(ProgramBlock pb) {
if (pb instanceof WhileProgramBlock) {
WhileStatementBlock sb = (WhileStatementBlock) pb.getStatementBlock();
sb.getPredicateHops().resetVisitStatus();
return pruneHasOnlyUnknownMR(sb.getPredicateHops());
} else if (pb instanceof IfProgramBlock) {
IfStatementBlock sb = (IfStatementBlock) pb.getStatementBlock();
sb.getPredicateHops().resetVisitStatus();
return pruneHasOnlyUnknownMR(sb.getPredicateHops());
} else if (pb instanceof ForProgramBlock) { // incl parfor
ForStatementBlock sb = (ForStatementBlock) pb.getStatementBlock();
sb.getFromHops().resetVisitStatus();
sb.getToHops().resetVisitStatus();
sb.getIncrementHops().resetVisitStatus();
return pruneHasOnlyUnknownMR(sb.getFromHops()) && pruneHasOnlyUnknownMR(sb.getToHops()) && pruneHasOnlyUnknownMR(sb.getIncrementHops());
} else { // last-level program blocks
StatementBlock sb = pb.getStatementBlock();
return pruneHasOnlyUnknownMR(sb.getHops());
}
}
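The block-level method above delegates to Hop-level overloads of pruneHasOnlyUnknownMR (taking a single Hop or a list of Hop roots) that are not shown on this page. The following is a minimal hypothetical sketch of such a Hop-level walk, not the actual ResourceOptimizer code; it assumes the standard Hop accessors getExecType(), dimsKnown(), and getInput(), with ExecType from LopProperties, and omits visit-status handling.
private static boolean pruneHasOnlyUnknownMR(Hop hop) {
    if (hop == null)
        return true;
    //an MR operation with known dimensions disqualifies the block from pruning (assumed criterion)
    if (hop.getExecType() == ExecType.MR && hop.dimsKnown())
        return false;
    //recursively check all inputs of the DAG (visit-status handling omitted for brevity)
    boolean ret = true;
    for (Hop in : hop.getInput())
        ret &= pruneHasOnlyUnknownMR(in);
    return ret;
}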
Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
The class GraphBuilder, method constructGDFGraph.
@SuppressWarnings("unchecked")
private static void constructGDFGraph(ProgramBlock pb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
if (pb instanceof FunctionProgramBlock) {
throw new DMLRuntimeException("FunctionProgramBlocks not implemented yet.");
} else if (pb instanceof WhileProgramBlock) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
WhileStatementBlock wsb = (WhileStatementBlock) pb.getStatementBlock();
//construct predicate node (conceptually sequence of from/to/incr)
GDFNode pred = constructGDFGraph(wsb.getPredicateHops(), wpb, new HashMap<Long, GDFNode>(), roots);
HashMap<String, GDFNode> inputs = constructLoopInputNodes(wpb, wsb, roots);
HashMap<String, GDFNode> lroots = (HashMap<String, GDFNode>) inputs.clone();
//process child blocks
for (ProgramBlock pbc : wpb.getChildBlocks()) constructGDFGraph(pbc, lroots);
HashMap<String, GDFNode> outputs = constructLoopOutputNodes(wsb, lroots);
GDFLoopNode lnode = new GDFLoopNode(wpb, pred, inputs, outputs);
//construct crossblock nodes
constructLoopOutputCrossBlockNodes(wsb, lnode, outputs, roots, wpb);
} else if (pb instanceof IfProgramBlock) {
IfProgramBlock ipb = (IfProgramBlock) pb;
IfStatementBlock isb = (IfStatementBlock) pb.getStatementBlock();
//construct predicate
if (isb.getPredicateHops() != null) {
Hop pred = isb.getPredicateHops();
roots.put(pred.getName(), constructGDFGraph(pred, ipb, new HashMap<Long, GDFNode>(), roots));
}
//construct if and else branch separately
HashMap<String, GDFNode> ifRoots = (HashMap<String, GDFNode>) roots.clone();
HashMap<String, GDFNode> elseRoots = (HashMap<String, GDFNode>) roots.clone();
for (ProgramBlock pbc : ipb.getChildBlocksIfBody()) constructGDFGraph(pbc, ifRoots);
if (ipb.getChildBlocksElseBody() != null)
for (ProgramBlock pbc : ipb.getChildBlocksElseBody()) constructGDFGraph(pbc, elseRoots);
//merge data flow roots (if no else, elseRoots refer to original roots)
reconcileMergeIfProgramBlockOutputs(ifRoots, elseRoots, roots, ipb);
} else if (pb instanceof ForProgramBlock) { //incl parfor
ForProgramBlock fpb = (ForProgramBlock) pb;
ForStatementBlock fsb = (ForStatementBlock) pb.getStatementBlock();
//construct predicate node (conceptually sequence of from/to/incr)
GDFNode pred = constructForPredicateNode(fpb, fsb, roots);
HashMap<String, GDFNode> inputs = constructLoopInputNodes(fpb, fsb, roots);
HashMap<String, GDFNode> lroots = (HashMap<String, GDFNode>) inputs.clone();
//process child blocks
for (ProgramBlock pbc : fpb.getChildBlocks()) constructGDFGraph(pbc, lroots);
HashMap<String, GDFNode> outputs = constructLoopOutputNodes(fsb, lroots);
GDFLoopNode lnode = new GDFLoopNode(fpb, pred, inputs, outputs);
//construct crossblock nodes
constructLoopOutputCrossBlockNodes(fsb, lnode, outputs, roots, fpb);
} else { //last-level program block
StatementBlock sb = pb.getStatementBlock();
ArrayList<Hop> hops = sb.get_hops();
if (hops != null) {
//create new local memo structure for local dag
HashMap<Long, GDFNode> lmemo = new HashMap<Long, GDFNode>();
for (Hop hop : hops) {
//recursively construct GDF graph for hop dag root
GDFNode root = constructGDFGraph(hop, pb, lmemo, roots);
if (root == null)
throw new HopsException("GDFGraphBuilder: failed to constuct dag root for: " + Explain.explain(hop));
//create cross block nodes for all transient writes
if (hop instanceof DataOp && ((DataOp) hop).getDataOpType() == DataOpTypes.TRANSIENTWRITE)
root = new GDFCrossBlockNode(hop, pb, root, hop.getName());
//add GDF root node to global roots
roots.put(hop.getName(), root);
}
}
}
}
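A hedged sketch of a driver over a complete runtime program follows; the method name constructGlobalGraph and the exact public entry point of GraphBuilder are assumptions, but it illustrates how the per-block method above accumulates cross-block roots.
public static ArrayList<GDFNode> constructGlobalGraph(Program prog) throws DMLRuntimeException, HopsException {
    //global roots map, shared across top-level program blocks (guaranteed to be unconditional)
    HashMap<String, GDFNode> roots = new HashMap<String, GDFNode>();
    for (ProgramBlock pb : prog.getProgramBlocks())
        constructGDFGraph(pb, roots);
    //return the data flow roots of the overall program
    return new ArrayList<GDFNode>(roots.values());
}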
Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
The class InterProceduralAnalysis, method analyzeProgram.
/**
* Public interface to perform IPA over a given DML program.
*
* @param dmlp the dml program
* @throws HopsException if HopsException occurs
* @throws ParseException if ParseException occurs
* @throws LanguageException if LanguageException occurs
*/
public void analyzeProgram(DMLProgram dmlp) throws HopsException, ParseException, LanguageException {
FunctionCallGraph fgraph = new FunctionCallGraph(dmlp);
//step 1: get candidates for statistics propagation into functions (if required)
Map<String, Integer> fcandCounts = new HashMap<String, Integer>();
Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>();
Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>();
if (!dmlp.getFunctionStatementBlocks().isEmpty()) {
//get candidates (over entire program)
for (StatementBlock sb : dmlp.getStatementBlocks()) getFunctionCandidatesForStatisticPropagation(sb, fcandCounts, fcandHops);
pruneFunctionCandidatesForStatisticPropagation(fcandCounts, fcandHops);
determineFunctionCandidatesNNZPropagation(fcandHops, fcandSafeNNZ);
DMLTranslator.resetHopsDAGVisitStatus(dmlp);
}
//step 2: get unary dimension-preserving non-candidate functions
Collection<String> unaryFcandTmp = fgraph.getReachableFunctions(fcandCounts.keySet());
HashSet<String> unaryFcands = new HashSet<String>();
if (!unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS) {
for (String tmp : unaryFcandTmp) if (isUnarySizePreservingFunction(dmlp.getFunctionStatementBlock(tmp)))
unaryFcands.add(tmp);
}
//step 3: propagate statistics and scalars into functions and across DAGs
if (!fcandCounts.isEmpty() || INTRA_PROCEDURAL_ANALYSIS) {
//(callVars used to chain outputs/inputs of multiple functions calls)
LocalVariableMap callVars = new LocalVariableMap();
//propagate stats into candidates
for (StatementBlock sb : dmlp.getStatementBlocks()) propagateStatisticsAcrossBlock(sb, fcandCounts, callVars, fcandSafeNNZ, unaryFcands, new HashSet<String>());
}
//step 4: remove unused functions (e.g., inlined or never called)
if (REMOVE_UNUSED_FUNCTIONS) {
removeUnusedFunctions(dmlp, fgraph);
}
//step 5: flag functions with loops for 'recompile-on-entry'
if (FLAG_FUNCTION_RECOMPILE_ONCE) {
flagFunctionsForRecompileOnce(dmlp, fgraph);
}
//step 6: set global data flow properties
if (REMOVE_UNNECESSARY_CHECKPOINTS && OptimizerUtils.isSparkExecutionMode()) {
//remove unnecessary checkpoint before update
removeCheckpointBeforeUpdate(dmlp);
//move necessary checkpoint after update
moveCheckpointAfterUpdate(dmlp);
//remove unnecessary checkpoint read-{write|uagg}
removeCheckpointReadWrite(dmlp);
}
//step 7: remove constant binary ops
if (REMOVE_CONSTANT_BINARY_OPS) {
removeConstantBinaryOps(dmlp);
}
}
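A hedged call-site sketch follows; the variable dmlp, the default constructor, and the ALLOW_INTER_PROCEDURAL_ANALYSIS guard are assumptions about how the SystemML compiler invokes IPA, and the actual integration point may differ.
//hypothetical call-site sketch: run IPA after parsing and HOP DAG construction
if (OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS) { //assumed compiler flag
    InterProceduralAnalysis ipa = new InterProceduralAnalysis();
    ipa.analyzeProgram(dmlp); //dmlp is the parsed DMLProgram
}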
Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
The class InterProceduralAnalysis, method rFlagFunctionForRecompileOnce.
/**
* Returns true if this statement block requires recompilation inside a
* loop statement block.
*
* @param sb statement block
* @param inLoop true if in loop
* @return true if statement block requires recompilation inside a loop statement block
*/
public boolean rFlagFunctionForRecompileOnce(StatementBlock sb, boolean inLoop) {
boolean ret = false;
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock c : fstmt.getBody()) ret |= rFlagFunctionForRecompileOnce(c, inLoop);
} else if (sb instanceof WhileStatementBlock) {
//recompilation information not available at this point
ret = true;
/*
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
ret |= (inLoop && wsb.requiresPredicateRecompilation() );
for( StatementBlock c : wstmt.getBody() )
ret |= rFlagFunctionForRecompileOnce( c, true );
*/
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
ret |= (inLoop && isb.requiresPredicateRecompilation());
for (StatementBlock c : istmt.getIfBody()) ret |= rFlagFunctionForRecompileOnce(c, inLoop);
for (StatementBlock c : istmt.getElseBody()) ret |= rFlagFunctionForRecompileOnce(c, inLoop);
} else if (sb instanceof ForStatementBlock) {
//recompilation information not available at this point
ret = true;
/*
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
for( StatementBlock c : fstmt.getBody() )
ret |= rFlagFunctionForRecompileOnce( c, true );
*/
} else {
ret |= (inLoop && sb.requiresRecompilation());
}
return ret;
}
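A hedged sketch of how this recursive check could be consumed by step 5 of analyzeProgram follows; it is hypothetical, the actual flagFunctionsForRecompileOnce has a different signature, and setRecompileOnce is assumed as the flag setter on FunctionStatementBlock.
private void flagFunctionsForRecompileOnceSketch(DMLProgram dmlp) {
    for (FunctionStatementBlock fsb : dmlp.getFunctionStatementBlocks())
        //start outside any loop; functions containing loops are flagged for recompile-on-entry
        if (rFlagFunctionForRecompileOnce(fsb, false))
            fsb.setRecompileOnce(true); //assumed setter
}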
Use of org.apache.sysml.parser.StatementBlock in project incubator-systemml by apache.
The class InterProceduralAnalysis, method moveCheckpointAfterUpdate.
private void moveCheckpointAfterUpdate(DMLProgram dmlp) throws HopsException {
//approach: scan over top-level program (guaranteed to be unconditional),
//collect checkpoints; determine if used before update; move first checkpoint
//after update if not used before update (best effort move which often avoids
//the second checkpoint on loops even though used in between)
HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
for (StatementBlock sb : dmlp.getStatementBlocks()) {
//prune candidates (used before being updated)
Set<String> cands = new HashSet<String>(chkpointCand.keySet());
for (String cand : cands) if (sb.variablesRead().containsVariable(cand) && !sb.variablesUpdated().containsVariable(cand)) {
//note: variablesRead might include false positives due to meta
//data operations like nrow(X) or operations removed by rewrites
//double check hops on basic blocks; otherwise worst-case
boolean skipRemove = false;
if (sb.get_hops() != null) {
Hop.resetVisitStatus(sb.get_hops());
skipRemove = true;
for (Hop root : sb.get_hops()) skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
}
if (!skipRemove)
chkpointCand.remove(cand);
}
//prune candidates (updated in conditional control flow)
Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
if (sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
for (String cand : cands2) if (sb.variablesUpdated().containsVariable(cand)) {
chkpointCand.remove(cand);
}
} else { //move checkpoint after update with simple read chain
//(note: right now this only applies if the checkpoint comes from a previous
//statement block; within-dag checkpoints should be handled during injection)
for (String cand : cands2) if (sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null) {
Hop.resetVisitStatus(sb.get_hops());
for (Hop root : sb.get_hops()) if (root.getName().equals(cand)) {
if (HopRewriteUtils.rHasSimpleReadChain(root, cand)) {
chkpointCand.get(cand).setRequiresCheckpoint(false);
root.getInput().get(0).setRequiresCheckpoint(true);
chkpointCand.put(cand, root.getInput().get(0));
} else
chkpointCand.remove(cand);
}
}
}
//collect checkpoints
ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
for (Hop chkpoint : tmp) {
chkpointCand.put(chkpoint.getName(), chkpoint);
}
}
}
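The collectCheckpoints helper invoked above is not shown on this page. A minimal hedged sketch of what it could look like follows; it is hypothetical, tracks visited DAG nodes via hop IDs rather than the Hop visit status used elsewhere, and assumes requiresCheckpoint() as the getter matching setRequiresCheckpoint above.
private ArrayList<Hop> collectCheckpoints(ArrayList<Hop> roots) {
    ArrayList<Hop> ret = new ArrayList<Hop>();
    if (roots != null) {
        Set<Long> visited = new HashSet<Long>();
        for (Hop root : roots)
            rCollectCheckpoints(root, visited, ret);
    }
    return ret;
}

private void rCollectCheckpoints(Hop hop, Set<Long> visited, ArrayList<Hop> checkpoints) {
    //skip nodes already processed (hop DAGs share subexpressions)
    if (!visited.add(hop.getHopID()))
        return;
    if (hop.requiresCheckpoint()) //assumed getter
        checkpoints.add(hop);
    for (Hop in : hop.getInput())
        rCollectCheckpoints(in, visited, checkpoints);
}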