Search in sources:

Example 31 with FunctionStatementBlock

Use of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.

Method getFunctionCandidatesForStatisticPropagation of class InterProceduralAnalysis.

/**
 * Scans a HOP DAG for calls to user-defined functions and tracks which of them
 * are candidates for inter-procedural statistics propagation.
 * <p>
 * On the first call to a function, its counter in {@code fcandCounts} is set to 1,
 * the call hop is recorded in {@code fcandHops}, and the function body is scanned
 * recursively. Every later call increments the counter (meaning the function is no
 * candidate), unless {@code ALLOW_MULTIPLE_FUNCTION_CALLS} is enabled and the later
 * call's inputs match those of the recorded first call. Calls into
 * {@code DMLProgram.INTERNAL_NAMESPACE} are not tracked, but their inputs are still scanned.
 *
 * @param prog  DML program used to look up the statement blocks of called functions
 * @param hop  root of the HOP sub-DAG to scan; already visited hops are skipped
 * @param fcandCounts  per-function call counters, keyed by function key (updated in place)
 * @param fcandHops  first recorded call hop per function key (updated in place)
 * @throws HopsException  If a HopsException occurs.
 * @throws ParseException  If a ParseException occurs.
 */
private void getFunctionCandidatesForStatisticPropagation(DMLProgram prog, Hop hop, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) throws HopsException, ParseException {
    if (hop.isVisited())
        return;
    if (hop instanceof FunctionOp && !((FunctionOp) hop).getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE)) {
        //maintain counters and investigate functions if not seen so far
        FunctionOp fop = (FunctionOp) hop;
        String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName());
        if (fcandCounts.containsKey(fkey)) {
            //repeated call: without ALLOW_MULTIPLE_FUNCTION_CALLS it is never consistent
            boolean consistent = ALLOW_MULTIPLE_FUNCTION_CALLS;
            if (consistent) {
                //compare input characteristics against the first recorded call
                //(if unknown or different: maintain counter - this function is no candidate)
                FunctionOp efop = fcandHops.get(fkey);
                int numInputs = efop.getInput().size();
                //calls with a different number of inputs cannot be compared position-wise;
                //without this check, fop.getInput().get(i) could go out of bounds
                consistent = (numInputs == fop.getInput().size());
                for (int i = 0; consistent && i < numInputs; i++) {
                    Hop h1 = efop.getInput().get(i);
                    Hop h2 = fop.getInput().get(i);
                    //check matrix and scalar sizes (if known dims, nnz known/unknown, 
                    // safeness of nnz propagation, determined later per input)
                    consistent = h1.dimsKnown() && h2.dimsKnown() && h1.getDim1() == h2.getDim1() && h1.getDim2() == h2.getDim2() && h1.getNnz() == h2.getNnz();
                    //check literal values (equi value)
                    if (consistent && h1 instanceof LiteralOp) {
                        consistent = h2 instanceof LiteralOp && HopRewriteUtils.isEqualValue((LiteralOp) h1, (LiteralOp) h2);
                    }
                }
            }
            //if differences (or multiple calls not allowed), do not propagate
            if (!consistent)
                fcandCounts.put(fkey, fcandCounts.get(fkey) + 1);
        } else {
            //first appearance: create a new count entry and keep the function call hop
            fcandCounts.put(fkey, 1);
            fcandHops.put(fkey, fop);
            //investigate the called function's body for nested candidates
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
            getFunctionCandidatesForStatisticPropagation(fsb, fcandCounts, fcandHops);
        }
    }
    //recurse into inputs before marking this hop as visited
    for (Hop c : hop.getInput())
        getFunctionCandidatesForStatisticPropagation(prog, c, fcandCounts, fcandHops);
    hop.setVisited();
}
Also used : FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) Hop(org.apache.sysml.hops.Hop) FunctionOp(org.apache.sysml.hops.FunctionOp) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 32 with FunctionStatementBlock

Use of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.

Method rFlagFunctionForRecompileOnce of class InterProceduralAnalysis.

/**
 * Decides whether a statement block (recursively, including nested blocks)
 * should be flagged as requiring recompilation when it executes inside a loop.
 * <p>
 * While and for (including parfor) blocks always report {@code true}, because
 * the recompilation information is not yet available at this point. If blocks
 * report their predicate's recompilation need (only when {@code inLoop}) combined
 * with the result of both branches; function blocks combine the results of their
 * body; all other blocks report {@code inLoop && sb.requiresRecompilation()}.
 *
 * @param sb statement block to inspect
 * @param inLoop true if the block is located inside a loop
 * @return true if the statement block requires recompilation inside a loop statement block
 */
public boolean rFlagFunctionForRecompileOnce(StatementBlock sb, boolean inLoop) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funcStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        boolean flag = false;
        //visit every child (no short-circuit) and combine their results
        for (StatementBlock child : funcStmt.getBody())
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        return flag;
    }
    if (sb instanceof WhileStatementBlock) {
        //recompilation information not available at this point, so be conservative
        //(a finer-grained variant would check the predicate and recurse with inLoop=true)
        return true;
    }
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        boolean flag = inLoop && ifBlock.requiresPredicateRecompilation();
        for (StatementBlock child : ifStmt.getIfBody())
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        for (StatementBlock child : ifStmt.getElseBody())
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        return flag;
    }
    if (sb instanceof ForStatementBlock) {
        //recompilation information not available at this point, so be conservative
        //(a finer-grained variant would recurse into the body with inLoop=true)
        return true;
    }
    //generic (last-level) block
    return inLoop && sb.requiresRecompilation();
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 33 with FunctionStatementBlock

Use of org.apache.sysml.parser.FunctionStatementBlock in the Apache incubator-systemml project.

Method propagateStatisticsAcrossBlock of class InterProceduralAnalysis.

/////////////////////////////
// INTRA-PROCEDURE ANALYSIS
/////////////////////////////

/**
 * Propagates statistics within a procedure by walking a statement block.
 * <p>
 * Function blocks forward to every block of their body. While, if and for
 * (including parfor) blocks first push the incoming statistics into their
 * predicate hops. They then process their bodies and reconcile the statistics
 * afterwards. Loops get a second pass over the body (while loops also re-process
 * the predicate) when reconciliation reports a change. Generic (last-level)
 * blocks replace scalar reads with literals, refresh the statistics of their HOP
 * DAG, and push them into contained function calls.
 *
 * @param sb  DML statement block to process.
 * @param fcand  Function candidates.
 * @param callVars  Map of variables eligible for propagation (updated in place).
 * @param fcandSafeNNZ  Function candidate safe non-zeros.
 * @param unaryFcands  Unary function candidates.
 * @param fnStack  Function stack to determine current scope.
 * @throws HopsException  If a HopsException occurs.
 * @throws ParseException  If a ParseException occurs.
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funcStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        propagateStatisticsAcrossBlockList(funcStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        //incoming stats into the predicate, then drop constant scalars the loop updates
        propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
        Recompiler.removeUpdatedScalars(callVars, whileBlock);
        //first pass over the body, starting from a snapshot of the incoming stats
        LocalVariableMap beforeBody = (LocalVariableMap) callVars.clone();
        propagateStatisticsAcrossBlockList(whileStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        //second pass (predicate and body) only if reconciliation requires it
        if (Recompiler.reconcileUpdatedCallVarsLoops(beforeBody, callVars, whileBlock)) {
            propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
            propagateStatisticsAcrossBlockList(whileStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        propagateStatisticsAcrossPredicateDAG(ifBlock.getPredicateHops(), callVars);
        //the if branch works on callVars itself, the else branch on its own copy
        LocalVariableMap beforeBranches = (LocalVariableMap) callVars.clone();
        LocalVariableMap elseVars = (LocalVariableMap) callVars.clone();
        propagateStatisticsAcrossBlockList(ifStmt.getIfBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        propagateStatisticsAcrossBlockList(ifStmt.getElseBody(), fcand, elseVars, fcandSafeNNZ, unaryFcands, fnStack);
        //NOTE(review): the reconciled map is only held locally; the caller observes it only
        //if reconcileUpdatedCallVarsIf returns the same object as callVars - confirm
        LocalVariableMap merged = Recompiler.reconcileUpdatedCallVarsIf(beforeBranches, callVars, elseVars, ifBlock);
        Recompiler.removeUpdatedScalars(merged, sb);
        return;
    }

    if (sb instanceof ForStatementBlock) { //incl parfor
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        //incoming stats into the from/to/increment expressions
        propagateStatisticsAcrossPredicateDAG(forBlock.getFromHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getToHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getIncrementHops(), callVars);
        Recompiler.removeUpdatedScalars(callVars, forBlock);
        LocalVariableMap beforeBody = (LocalVariableMap) callVars.clone();
        propagateStatisticsAcrossBlockList(forStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        //second pass over the body only if reconciliation requires it
        if (Recompiler.reconcileUpdatedCallVarsLoops(beforeBody, callVars, forBlock)) {
            propagateStatisticsAcrossBlockList(forStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    //generic (last-level) block: old stats in, new stats out if updated
    Recompiler.removeUpdatedScalars(callVars, sb);
    ArrayList<Hop> roots = sb.get_hops();
    DMLProgram prog = sb.getDMLProg();
    //each DAG traversal below starts from a clean visit status
    Hop.resetVisitStatus(roots);
    propagateScalarsAcrossDAG(roots, callVars);
    Hop.resetVisitStatus(roots);
    propagateStatisticsAcrossDAG(roots, callVars);
    Hop.resetVisitStatus(roots);
    propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}

/**
 * Calls {@code propagateStatisticsAcrossBlock} on each statement block, in iteration order,
 * with the same propagation state.
 */
private void propagateStatisticsAcrossBlockList(Iterable<? extends StatementBlock> blocks, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
    for (StatementBlock block : blocks)
        propagateStatisticsAcrossBlock(block, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) ArrayList(java.util.ArrayList) DMLProgram(org.apache.sysml.parser.DMLProgram) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Aggregations

FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock)33 FunctionStatement (org.apache.sysml.parser.FunctionStatement)24 StatementBlock (org.apache.sysml.parser.StatementBlock)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)22 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)22 IfStatement (org.apache.sysml.parser.IfStatement)14 ForStatement (org.apache.sysml.parser.ForStatement)12 ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock)12 WhileStatement (org.apache.sysml.parser.WhileStatement)12 ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement)11 Hop (org.apache.sysml.hops.Hop)10 ArrayList (java.util.ArrayList)9 FunctionOp (org.apache.sysml.hops.FunctionOp)9 DMLProgram (org.apache.sysml.parser.DMLProgram)7 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)6 LanguageException (org.apache.sysml.parser.LanguageException)4 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)4 HashSet (java.util.HashSet)3 LiteralOp (org.apache.sysml.hops.LiteralOp)3