Use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by Apache.
Class InterProceduralAnalysis, method getFunctionCandidatesForStatisticPropagation.
/**
 * Recursively collects candidate functions for inter-procedural statistics
 * propagation from the hop DAG rooted at {@code hop}.
 * <p>
 * Every call to a function outside the internal namespace is counted per
 * function key:
 * <ul>
 * <li>On the first call, the function gets count 1, its call hop is kept, and
 * its function statement block is scanned recursively.</li>
 * <li>A later call increments the count, unless
 * {@code ALLOW_MULTIPLE_FUNCTION_CALLS} is set and the call is consistent with
 * the first recorded call.</li>
 * </ul>
 * A call is consistent when it has the same number of inputs, all input
 * dimensions are known and equal, nnz values are equal, and literal inputs
 * have equal values.
 * Already visited hops are skipped; the hop is marked visited after its inputs
 * have been processed.
 *
 * @param prog DML program used to look up function statement blocks
 * @param hop current hop of the DAG traversal
 * @param fcandCounts per function key, the number of calls that disqualify reuse
 *        (1 = only consistent calls seen so far)
 * @param fcandHops per function key, the first function call hop encountered
 * @throws HopsException if a HopsException occurs
 * @throws ParseException if a ParseException occurs
 */
private void getFunctionCandidatesForStatisticPropagation(DMLProgram prog, Hop hop, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) throws HopsException, ParseException {
	if (hop.isVisited())
		return;
	if (hop instanceof FunctionOp && !((FunctionOp) hop).getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE)) {
		//maintain counters and investigate functions if not seen so far
		FunctionOp fop = (FunctionOp) hop;
		String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName());
		if (fcandCounts.containsKey(fkey)) {
			if (ALLOW_MULTIPLE_FUNCTION_CALLS) {
				//compare input characteristics of this call and the first recorded call
				//(if unknown or different: increment counter - this function is no candidate)
				FunctionOp efop = fcandHops.get(fkey);
				int numInputs = efop.getInput().size();
				//calls with a different number of arguments are inconsistent. Without this
				//check, fop.getInput().get(i) below throws IndexOutOfBoundsException when
				//this call has fewer arguments than the first one, and extra arguments are
				//silently ignored when it has more (NOTE(review): presumably possible when
				//parameters with default values are omitted)
				boolean consistent = (numInputs == fop.getInput().size());
				for (int i = 0; consistent && i < numInputs; i++) {
					Hop h1 = efop.getInput().get(i);
					Hop h2 = fop.getInput().get(i);
					//check matrix and scalar sizes (if known dims, nnz known/unknown,
					//safeness of nnz propagation, determined later per input)
					consistent = h1.dimsKnown() && h2.dimsKnown()
						&& h1.getDim1() == h2.getDim1() && h1.getDim2() == h2.getDim2()
						&& h1.getNnz() == h2.getNnz();
					//check literal values (equi value)
					if (consistent && h1 instanceof LiteralOp) {
						consistent = (h2 instanceof LiteralOp
							&& HopRewriteUtils.isEqualValue((LiteralOp) h1, (LiteralOp) h2));
					}
				}
				//if differences, do not propagate
				if (!consistent)
					fcandCounts.put(fkey, fcandCounts.get(fkey) + 1);
			} else {
				//maintain counter (this function is no candidate)
				fcandCounts.put(fkey, fcandCounts.get(fkey) + 1);
			}
		} else {
			//first appearance: create a new count entry and keep the function call hop
			fcandCounts.put(fkey, 1);
			fcandHops.put(fkey, fop);
			//investigate the called function itself
			FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
			getFunctionCandidatesForStatisticPropagation(fsb, fcandCounts, fcandHops);
		}
	}
	for (Hop c : hop.getInput())
		getFunctionCandidatesForStatisticPropagation(prog, c, fcandCounts, fcandHops);
	hop.setVisited();
}
Use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by Apache.
Class InterProceduralAnalysis, method rFlagFunctionForRecompileOnce.
/**
 * Decides whether a statement block, including all blocks nested in it, needs
 * recompilation when it is executed inside a loop.
 * <ul>
 * <li>While and for blocks (including subclasses) always return {@code true}.
 * The recompilation information needed for a precise answer is not available
 * at this point.</li>
 * <li>Function blocks return {@code true} if any block of the function body
 * does.</li>
 * <li>If blocks return {@code true} if {@code inLoop} is set and the predicate
 * requires recompilation, or if any block of the if or else body does.</li>
 * <li>All other blocks return {@code true} only if {@code inLoop} is set and
 * the block itself requires recompilation.</li>
 * </ul>
 * All nested blocks are always visited; the results are combined with a
 * non-short-circuit OR.
 *
 * @param sb statement block to inspect
 * @param inLoop whether {@code sb} is executed inside a loop
 * @return {@code true} if the block requires recompilation inside a loop
 */
public boolean rFlagFunctionForRecompileOnce(StatementBlock sb, boolean inLoop) {
	if (sb instanceof FunctionStatementBlock) {
		FunctionStatement funcStmt = (FunctionStatement) sb.getStatement(0);
		boolean anyNested = false;
		for (StatementBlock nested : funcStmt.getBody())
			anyNested |= rFlagFunctionForRecompileOnce(nested, inLoop);
		return anyNested;
	}
	if (sb instanceof WhileStatementBlock) {
		//conservative answer: a precise per-predicate/per-body analysis is not
		//possible because recompilation information is not available yet
		return true;
	}
	if (sb instanceof IfStatementBlock) {
		IfStatementBlock ifBlock = (IfStatementBlock) sb;
		IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
		boolean anyNested = inLoop && ifBlock.requiresPredicateRecompilation();
		for (StatementBlock nested : ifStmt.getIfBody())
			anyNested |= rFlagFunctionForRecompileOnce(nested, inLoop);
		for (StatementBlock nested : ifStmt.getElseBody())
			anyNested |= rFlagFunctionForRecompileOnce(nested, inLoop);
		return anyNested;
	}
	if (sb instanceof ForStatementBlock) {
		//conservative answer, same reasoning as for while loops
		return true;
	}
	//generic (last-level) block
	return inLoop && sb.requiresRecompilation();
}
Use of org.apache.sysml.parser.FunctionStatementBlock in project incubator-systemml by Apache.
Class InterProceduralAnalysis, method propagateStatisticsAcrossBlock.
/////////////////////////////
// INTRA-PROCEDURE ANALYSIS
//////
/**
 * Propagates statistics (and constant scalars) across a statement block and all
 * blocks nested in it. This is the intra-procedural part of the analysis.
 * <ul>
 * <li>Function blocks process their body blocks in order, with the same
 * {@code callVars}.</li>
 * <li>While and for (incl. parfor) blocks first push the incoming statistics
 * into their predicate hops. They then remove updated constant scalars and
 * process the body. If the reconciliation of the pre-body snapshot with the
 * post-body state requests it, they process the body a second time.</li>
 * <li>If blocks process the if-branch on {@code callVars} and the else-branch
 * on a copy, then reconcile both.</li>
 * <li>Last-level blocks replace scalar reads with literals and refresh the
 * statistics of their hop DAGs. They then propagate statistics into the
 * function calls they contain.</li>
 * </ul>
 *
 * @param sb statement block to process
 * @param fcand function candidates (NOTE(review): presumably function key to
 *        call count; verify against the caller)
 * @param callVars map of variables eligible for propagation; nested calls are
 *        expected to update it in place (loops snapshot it before the body to
 *        detect changes)
 * @param fcandSafeNNZ function candidate safe non-zeros
 * @param unaryFcands unary function candidates
 * @param fnStack function stack to determine current scope
 * @throws HopsException If a HopsException occurs.
 * @throws ParseException If a ParseException occurs.
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
if (sb instanceof FunctionStatementBlock) {
//function: process all body blocks in order with the same callVars
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
} else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
//old stats into predicate
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
//remove updated constant scalars before the body (presumably so that
//scalars modified inside the loop are not treated as constants there)
Recompiler.removeUpdatedScalars(callVars, wsb);
//check and propagate stats into body; snapshot callVars first so that
//the reconciliation can compare stats before and after the body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb)) {
//second pass if required (reconciliation returned true), re-processing
//the predicate and the body with the reconciled stats
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
//old stats into predicate
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
//check and propagate stats into body: the if-branch works on callVars,
//the else-branch on a separate copy, both starting from the same state
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : istmt.getIfBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
for (StatementBlock sbi : istmt.getElseBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVarsElse, fcandSafeNNZ, unaryFcands, fnStack);
//NOTE(review): this only rebinds the local parameter; the caller sees the
//reconciled stats only if reconcileUpdatedCallVarsIf updates and returns
//the passed-in if-branch map (callVars) - confirm in Recompiler
callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (//incl parfor
sb instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
//old stats into predicate (from, to, and increment expressions)
propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, fsb);
//check and propagate stats into body; snapshot first for reconciliation
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
//second pass over the body if required (unlike while loops, the
//from/to/increment predicates are not re-processed here)
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb))
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else //generic (last-level)
{
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
//old stats in, new stats out if updated
ArrayList<Hop> roots = sb.get_hops();
DMLProgram prog = sb.getDMLProg();
//replace scalar reads with literals
//(visit status is reset before each of the three DAG passes so that
//every pass traverses the full DAG)
Hop.resetVisitStatus(roots);
propagateScalarsAcrossDAG(roots, callVars);
//refresh stats across dag
Hop.resetVisitStatus(roots);
propagateStatisticsAcrossDAG(roots, callVars);
//propagate stats into function calls
Hop.resetVisitStatus(roots);
propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}
}
Aggregations