use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.
the class InterProceduralAnalysis method propagateStatisticsAcrossBlock.
/////////////////////////////
// INTRA-PROCEDURE ANALYSIS
//////
/**
* Perform intra-procedural analysis (IPA) by propagating statistics
* across statement blocks.
*
* @param sb DML statement blocks.
* @param fcand Function candidates.
* @param callVars Map of variables eligible for propagation.
* @param fcandSafeNNZ Function candidate safe non-zeros.
* @param unaryFcands Unary function candidates.
* @param fnStack Function stack to determine current scope.
* @throws HopsException If a HopsException occurs.
* @throws ParseException If a ParseException occurs.
*/
private void propagateStatisticsAcrossBlock(StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
} else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
//old stats into predicate
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, wsb);
//check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb)) {
//second pass if required
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
//old stats into predicate
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
//check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : istmt.getIfBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
for (StatementBlock sbi : istmt.getElseBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVarsElse, fcandSafeNNZ, unaryFcands, fnStack);
callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (//incl parfor
sb instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
//old stats into predicate
propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, fsb);
//check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb))
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else //generic (last-level)
{
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
//old stats in, new stats out if updated
ArrayList<Hop> roots = sb.get_hops();
DMLProgram prog = sb.getDMLProg();
//replace scalar reads with literals
Hop.resetVisitStatus(roots);
propagateScalarsAcrossDAG(roots, callVars);
//refresh stats across dag
Hop.resetVisitStatus(roots);
propagateStatisticsAcrossDAG(roots, callVars);
//propagate stats into function calls
Hop.resetVisitStatus(roots);
propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}
}
Aggregations